Unverified Commit 58631d7c authored by nemanjaudovic's avatar nemanjaudovic Committed by GitHub
Browse files

[Bugfix] Fix scaled_mm output narrowing for 3D input tensors (#38093)


Signed-off-by: nemanjaudovic <nudovic@amd.com>
parent a943839e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch import torch
...@@ -13,6 +14,13 @@ from .ScaledMMLinearKernel import ( ...@@ -13,6 +14,13 @@ from .ScaledMMLinearKernel import (
) )
def _get_num_tokens(output_shape: list) -> int:
# torch._scaled_mm works with 2D tensors, so input tensors are
# flattened if they are 3D. If output_shape is 3D, num_tokens is
# the product of all dims except the last (hidden dim).
return math.prod(output_shape[:-1])
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
""" """
Base class for FP8 linear kernels using Torch. Base class for FP8 linear kernels using Torch.
...@@ -78,7 +86,8 @@ class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): ...@@ -78,7 +86,8 @@ class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
if type(output) is tuple and len(output) == 2: if type(output) is tuple and len(output) == 2:
output = output[0] output = output[0]
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) num_tokens = _get_num_tokens(output_shape)
return torch.narrow(output, 0, 0, num_tokens).view(*output_shape)
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
...@@ -145,7 +154,8 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): ...@@ -145,7 +154,8 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
bias=bias, bias=bias,
) )
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape) num_tokens = _get_num_tokens(output_shape)
return torch.narrow(output, 0, 0, num_tokens).view(*output_shape)
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
...@@ -206,8 +216,9 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel): ...@@ -206,8 +216,9 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
output = output[0] output = output[0]
# Unpad (undo num_token_padding) # Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, output_shape[0]) num_tokens = _get_num_tokens(output_shape)
x_scale = torch.narrow(As, 0, 0, output_shape[0]) output = torch.narrow(output, 0, 0, num_tokens)
x_scale = torch.narrow(As, 0, 0, num_tokens)
# DQ # DQ
# C = sw * sx * (X * W) + bias # C = sw * sx * (X * W) + bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment