Unverified commit b215116a, authored by kwyss-nvidia, committed by GitHub
Browse files

Check calling convention for amax switch. (#2506)



* Check calling convention for amax switch.

Wgrad gemms with colwise x colwise require
rowwise data via general_gemm. Since dy
has both for dgrad and wgrad, the brittleness
has likely not affected results.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Clear rowwise data when applicable.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update test with columnwise cases.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Check enum value rather than implicit cast.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
parent 36f2dfd2
......@@ -122,7 +122,14 @@ def check_nvfp4_gemm_versus_reference(
)
# Create reference quantized tensors needed by reference GEMM
# Reference GEMM is only rowwise.
if x_columnwise:
x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous())
else:
x_nvfp4_ref = ref_quantizer.quantize(x)
if w_columnwise:
w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous())
else:
w_nvfp4_ref = ref_quantizer.quantize(w)
# Reference GEMM using quantizer's qgemm method
......@@ -155,6 +162,10 @@ def check_nvfp4_gemm_versus_reference(
use_grad = False
use_split_accumulator = False
if x_columnwise:
x_nvfp4_native.update_usage(rowwise_usage=False)
if w_columnwise:
w_nvfp4_native.update_usage(rowwise_usage=False)
# Native cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
......@@ -212,11 +223,11 @@ def check_nvfp4_gemm_versus_reference(
@pytest.mark.parametrize(
"is_x_columnwise, is_w_columnwise",
[
(False, False), # Only rowwise x rowwise is supported by reference GEMM
# Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
# Columnwise layouts are not supported by the reference implementation
(False, False), # TN
(True, False), # NN
(True, True), # NT
],
ids=["rowxrow"],
ids=["rowxrow", "colxrow", "colxcol"],
)
def test_nvfp4_gemm_versus_reference(
M: int,
......
......@@ -363,7 +363,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// TODO: Check whether scales are on CPU/GPU or add API to control.
// Currently scales are assumed to be on CPU when amax is provided
// and on GPU when not provided, but this is brittle.
if (use_fp4 && (inputA->amax.dptr != nullptr || inputB->amax.dptr != nullptr)) {
if (use_fp4 &&
((transa == CUBLAS_OP_T ? inputA->amax.dptr : inputA->columnwise_amax.dptr) != nullptr ||
(transb == CUBLAS_OP_T ? inputB->columnwise_amax.dptr : inputB->amax.dptr) != nullptr)) {
// Reserve some workspace for alpha scale
NVTE_CHECK(workspaceSize >= 4,
"NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ",
......@@ -378,8 +380,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// tensor scales in matmul output, instead of in matmul inputs.
float old_alpha = *reinterpret_cast<const float *>(alpha); // Assumed to be on CPU
TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector<size_t>{1}, DType::kFloat32);
nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb,
old_alpha, new_alpha_tensor.data(), stream);
bool a_rowwise_amax = transa == CUBLAS_OP_T;
bool b_rowwise_amax = transb != CUBLAS_OP_T;
nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, a_rowwise_amax, inputB->nvte_tensor,
b_rowwise_amax, old_alpha, new_alpha_tensor.data(), stream);
alpha = new_alpha_ptr;
// Make sure beta scale is on device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment