"examples/vscode:/vscode.git/clone" did not exist on "8c017b34908f8d4a877d862dd21b99aef7057c55"
Unverified Commit 58631d7c authored by nemanjaudovic's avatar nemanjaudovic Committed by GitHub
Browse files

[Bugfix] Fix scaled_mm output narrowing for 3D input tensors (#38093)


Signed-off-by: default avatarnemanjaudovic <nudovic@amd.com>
parent a943839e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
......@@ -13,6 +14,13 @@ from .ScaledMMLinearKernel import (
)
def _get_num_tokens(output_shape: list) -> int:
# torch._scaled_mm works with 2D tensors, so input tensors are
# flattened if they are 3D. If output_shape is 3D, num_tokens is
# the product of all dims except the last (hidden dim).
return math.prod(output_shape[:-1])
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
"""
Base class for FP8 linear kernels using Torch.
......@@ -78,7 +86,8 @@ class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
num_tokens = _get_num_tokens(output_shape)
return torch.narrow(output, 0, 0, num_tokens).view(*output_shape)
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
......@@ -145,7 +154,8 @@ class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
bias=bias,
)
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
num_tokens = _get_num_tokens(output_shape)
return torch.narrow(output, 0, 0, num_tokens).view(*output_shape)
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
......@@ -206,8 +216,9 @@ class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, output_shape[0])
x_scale = torch.narrow(As, 0, 0, output_shape[0])
num_tokens = _get_num_tokens(output_shape)
output = torch.narrow(output, 0, 0, num_tokens)
x_scale = torch.narrow(As, 0, 0, num_tokens)
# DQ
# C = sw * sx * (X * W) + bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment