Unverified commit f0d22ca1, authored by Xin Yao and committed by GitHub
Browse files

Fix a bug for D being nullptr in grouped gemm (#1475)



* fix a bug for at::from_blob with nullptr
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix a bug for non-TN
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ee4a17de
...@@ -2131,21 +2131,30 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): ...@@ -2131,21 +2131,30 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
if layout == "TN": if layout == "TN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input
out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
grad = False grad = False
single_output = True
elif layout == "NN": elif layout == "NN":
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output B = list(
out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # dgrad torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
) # grad_output
out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad
out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
grad = True grad = True
single_output = True
else: # layout == "NT" else: # layout == "NT"
A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input
B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output B = list(
torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
) # grad_output
out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad
out_ref = [o.clone() for o in out]
grad = True grad = True
single_output = False
out_ref = [o.clone() for o in out]
for i in range(z): for i in range(z):
general_gemm( general_gemm(
A[i], A[i],
...@@ -2157,17 +2166,20 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): ...@@ -2157,17 +2166,20 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
layout=layout, layout=layout,
out=out_ref[i], out=out_ref[i],
) )
if single_output:
out_ref = [torch.cat(out_ref)]
general_grouped_gemm( general_grouped_gemm(
A, A,
list(B), B,
list(out), out,
dtype, dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
m_splits=[k] * n, # TODO, not sure m_splits=m_splits,
grad=grad, grad=grad,
accumulate=accumulate, accumulate=accumulate,
layout=layout, layout=layout,
single_output=single_output,
) )
# should be bit-wise match # should be bit-wise match
...@@ -2190,7 +2202,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): ...@@ -2190,7 +2202,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
pytest.skip(reason_for_no_fp8) pytest.skip(reason_for_no_fp8)
z, m, k, n = shape z, m, k, n = shape
m_splits = m // z m_splits = [m // z] * z
dtype = torch.bfloat16 dtype = torch.bfloat16
A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight
...@@ -2242,7 +2254,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): ...@@ -2242,7 +2254,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
out, out,
dtype, dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
m_splits=[k] * m_splits, m_splits=m_splits,
accumulate=accumulate, accumulate=accumulate,
) )
......
...@@ -336,9 +336,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm( ...@@ -336,9 +336,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
auto dtype = GetATenDType(D_type); auto dtype = GetATenDType(D_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
if (single_output) { if (single_output) {
if (output_data_ptr == nullptr) {
out_tensor = at::empty(D_shape, opts);
} else {
out_tensor = at::from_blob(output_data_ptr, D_shape, opts); out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
}
char* char_ptr = reinterpret_cast<char*>(output_data_ptr); char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size(); char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
output_data_ptr = reinterpret_cast<void*>(char_ptr); output_data_ptr = reinterpret_cast<void*>(char_ptr);
D_vectors.emplace_back(out_tensor); D_vectors.emplace_back(out_tensor);
} else { } else {
......
...@@ -269,9 +269,10 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -269,9 +269,10 @@ class _GroupedLinear(torch.autograd.Function):
general_grouped_gemm( general_grouped_gemm(
weights, weights,
grad_output, grad_output,
torch.split(dgrad, ctx.m_splits), [dgrad],
ctx.activation_dtype, ctx.activation_dtype,
get_multi_stream_cublas_workspace(), get_multi_stream_cublas_workspace(),
single_output=True,
layout="NN", layout="NN",
m_splits=ctx.m_splits, m_splits=ctx.m_splits,
grad=True, grad=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment