Unverified Commit 6a515304 authored by Michael Goin, committed by GitHub
Browse files

[Bugfix] Fix 3D input passed into cutlass_scaled_mm (#22278)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 35509fc5
...@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor, ...@@ -710,23 +710,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
scale_b.shape * [128, 128] == b.shape scale_b.shape * [128, 128] == b.shape
""" """
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[ assert bias is None or bias.numel(
1] and bias.dtype == out_dtype ) == b.shape[1] and bias.dtype == out_dtype
m = a.shape[0] # Massage the input to be 2D
n = b.shape[1] target_shape = (*a.shape[:-1], b.shape[1])
a = a.view(-1, a.shape[-1])
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
if current_platform.is_rocm() or not cutlass_compatible_b: if current_platform.is_rocm() or not cutlass_compatible_b:
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
triton_scaled_mm) triton_scaled_mm)
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
else:
out = torch.empty((m, n), dtype=out_dtype, device=a.device) out = torch.empty((a.shape[0], b.shape[1]),
dtype=out_dtype,
device=a.device)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
return out return out.view(*target_shape)
def cutlass_scaled_mm_azp(a: torch.Tensor, def cutlass_scaled_mm_azp(a: torch.Tensor,
...@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor, ...@@ -746,15 +748,18 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.numel( assert bias is None or bias.numel(
) == b.shape[1] and bias.dtype == out_dtype ) == b.shape[1] and bias.dtype == out_dtype
assert azp is None or azp.numel() == a.shape[0]
m = a.shape[0] # Massage the input to be 2D
n = b.shape[1] target_shape = (*a.shape[:-1], b.shape[1])
out = torch.empty((m, n), dtype=out_dtype, device=a.device) a = a.view(-1, a.shape[-1])
assert azp is None or azp.numel() == a.shape[0]
out = torch.empty((a.shape[0], b.shape[1]),
dtype=out_dtype,
device=a.device)
torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj,
azp, bias) azp, bias)
return out return out.view(*target_shape)
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment