Unverified commit f76e85c2, authored by Alexander Matveev, committed by GitHub
Browse files

[Performance][Hopper] Avoid M dim padding to 4x for most cases (due to cuda...


[Performance][Hopper] Avoid M dim padding to 4x for most cases (due to cuda graphs paddings) (#28492)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 54aecd9e
...@@ -115,6 +115,9 @@ def _padded_cutlass( ...@@ -115,6 +115,9 @@ def _padded_cutlass(
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple) dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
) )
has_pad = padded > dim
if has_pad:
padded_shape = [padded, *qx.shape[1:]] padded_shape = [padded, *qx.shape[1:]]
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype) padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0 : qx.shape[0], ...].copy_(qx) padded_qx[0 : qx.shape[0], ...].copy_(qx)
...@@ -129,6 +132,10 @@ def _padded_cutlass( ...@@ -129,6 +132,10 @@ def _padded_cutlass(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
) )
return output[0 : qx.shape[0], ...] return output[0 : qx.shape[0], ...]
else:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _padded_cutlass_fake( def _padded_cutlass_fake(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment