Unverified Commit b215116a authored by kwyss-nvidia, committed by GitHub

Check calling convention for amax switch. (#2506)



* Check calling convention for amax switch.

Wgrad GEMMs with columnwise x columnwise operands require rowwise data when invoked via general_gemm. Since dy carries both rowwise and columnwise data (it is used for both dgrad and wgrad), the brittleness has likely not affected results.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clear rowwise data when applicable.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test with columnwise cases.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Check enum value rather than implicit cast.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
parent 36f2dfd2
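The heart of the change is the calling convention: which of a tensor's two amax values (rowwise vs. columnwise) is valid depends on the transpose flag passed to cuBLAS. Below is a minimal, standalone sketch of that convention as the diff encodes it; the enum, struct, and field names are simplified stand-ins for the cuBLAS and Transformer Engine types, not the real definitions.

// Minimal sketch of the amax selection convention from the diff below.
// cublasOperation_t, Tensor, and the field names are simplified
// stand-ins, not the actual cuBLAS / Transformer Engine types.
#include <cassert>
#include <cstdio>

enum cublasOperation_t { CUBLAS_OP_N = 0, CUBLAS_OP_T = 1 };

struct Tensor {
  const float *amax = nullptr;             // rowwise amax
  const float *columnwise_amax = nullptr;  // columnwise amax
};

// Per the diff: a transposed A operand carries rowwise data, so its
// rowwise amax is the meaningful one; a non-transposed A carries
// columnwise data. The convention is mirrored for B.
const float *select_amax_a(const Tensor &a, cublasOperation_t transa) {
  return transa == CUBLAS_OP_T ? a.amax : a.columnwise_amax;
}

const float *select_amax_b(const Tensor &b, cublasOperation_t transb) {
  return transb == CUBLAS_OP_T ? b.columnwise_amax : b.amax;
}

int main() {
  float row_amax = 1.0f, col_amax = 2.0f;
  Tensor a{&row_amax, &col_amax};
  assert(select_amax_a(a, CUBLAS_OP_T) == &row_amax);  // transposed -> rowwise
  assert(select_amax_a(a, CUBLAS_OP_N) == &col_amax);  // plain -> columnwise
  std::printf("amax convention holds\n");
  return 0;
}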
@@ -122,8 +122,15 @@ def check_nvfp4_gemm_versus_reference(
     )
     # Create reference quantized tensors needed by reference GEMM
-    x_nvfp4_ref = ref_quantizer.quantize(x)
-    w_nvfp4_ref = ref_quantizer.quantize(w)
+    # Reference GEMM is only rowwise.
+    if x_columnwise:
+        x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous())
+    else:
+        x_nvfp4_ref = ref_quantizer.quantize(x)
+    if w_columnwise:
+        w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous())
+    else:
+        w_nvfp4_ref = ref_quantizer.quantize(w)
 
     # Reference GEMM using quantizer's qgemm method
     y_ref = ref_quantizer.qgemm(
@@ -155,6 +162,10 @@ def check_nvfp4_gemm_versus_reference(
     use_grad = False
     use_split_accumulator = False
+    if x_columnwise:
+        x_nvfp4_native.update_usage(rowwise_usage=False)
+    if w_columnwise:
+        w_nvfp4_native.update_usage(rowwise_usage=False)
 
     # Native cuBLAS GEMM
     # return type is out, bias_grad, gelu_input, extra_output
     # We are just capturing out.
@@ -212,11 +223,11 @@ def check_nvfp4_gemm_versus_reference(
 @pytest.mark.parametrize(
     "is_x_columnwise, is_w_columnwise",
     [
-        (False, False),  # Only rowwise x rowwise is supported by reference GEMM
-        # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
-        # Columnwise layouts are not supported by the reference implementation
+        (False, False),  # TN
+        (True, False),  # NN
+        (True, True),  # NT
     ],
-    ids=["rowxrow"],
+    ids=["rowxrow", "colxrow", "colxcol"],
 )
 def test_nvfp4_gemm_versus_reference(
     M: int,
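The new parametrize comments suggest a simple mapping from the (is_x_columnwise, is_w_columnwise) flags to the layout letters: the first letter tracks x (rowwise -> T) and the second tracks w (columnwise -> T). A tiny sketch of that reading, derived purely from the three cases above rather than from the library:

// Hypothetical helper reproducing the flag -> layout-letter mapping
// implied by the parametrize comments; an illustration, not the
// library's rule.
#include <cstdio>
#include <string>

std::string layout(bool x_columnwise, bool w_columnwise) {
  std::string s;
  s += x_columnwise ? 'N' : 'T';  // x: rowwise -> T, columnwise -> N
  s += w_columnwise ? 'T' : 'N';  // w: rowwise -> N, columnwise -> T
  return s;
}

int main() {
  // Prints "TN NN NT", matching the three parametrized cases.
  std::printf("%s %s %s\n", layout(false, false).c_str(),
              layout(true, false).c_str(), layout(true, true).c_str());
  return 0;
}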
@@ -363,7 +363,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
   // TODO: Check whether scales are on CPU/GPU or add API to control.
   // Currently scales are assumed to be on CPU when amax is provided
   // and on GPU when not provided, but this is brittle.
-  if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
+  if (use_fp4 &&
+      ((transa == CUBLAS_OP_T ? inputA->amax.dptr : inputA->columnwise_amax.dptr) != nullptr ||
+       (transb == CUBLAS_OP_T ? inputB->columnwise_amax.dptr : inputB->amax.dptr) != nullptr)) {
     // Reserve some workspace for alpha scale
     NVTE_CHECK(workspaceSize >= 4,
                "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
@@ -378,8 +380,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
     // tensor scales in matmul output, instead of in matmul inputs.
     float old_alpha = *reinterpret_cast<const float *>(alpha);  // Assumed to be on CPU
     TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
-    nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
-                                        old_alpha, new_alpha_tensor.data(), stream);
+    bool a_rowwise_amax = transa == CUBLAS_OP_T;
+    bool b_rowwise_amax = transb != CUBLAS_OP_T;
+    nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, a_rowwise_amax, inputB->nvte_tensor,
+                                        b_rowwise_amax, old_alpha, new_alpha_tensor.data(), stream);
     alpha = new_alpha_ptr;
     // Make sure beta scale is on device
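The "check enum value rather than implicit cast" bullet refers to this last hunk: previously transa and !transb were passed where bools are expected, which is correct only because CUBLAS_OP_N happens to be 0 and CUBLAS_OP_T happens to be 1. A small standalone sketch of the hazard; the enum mirrors cuBLAS's actual enumerator values, and the callee is a stand-in for nvte_nvfp4_compute_per_tensor_scale, not its real signature:

#include <cstdio>

// Stand-in mirroring cuBLAS's enumerator values.
enum cublasOperation_t { CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2 };

// Stand-in for a callee that expects plain bools for the rowwise-amax
// flags, as the new call site assumes.
void compute_scale(bool a_rowwise_amax, bool b_rowwise_amax) {
  std::printf("a_rowwise=%d b_rowwise=%d\n", a_rowwise_amax, b_rowwise_amax);
}

int main() {
  cublasOperation_t transa = CUBLAS_OP_T, transb = CUBLAS_OP_N;
  // Old style: implicit enum -> bool conversion. Correct here only
  // because N is 0 and T is 1; any other enumerator (e.g. CUBLAS_OP_C)
  // would silently convert to true.
  compute_scale(transa, !transb);
  // New style: explicit comparison against the enumerator, independent
  // of the enum's numeric values.
  compute_scale(transa == CUBLAS_OP_T, transb != CUBLAS_OP_T);
  return 0;
}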