Unverified Commit f0d22ca1 authored by Xin Yao, committed by GitHub

Fix a bug for D being nullptr in grouped gemm (#1475)



* fix a bug for at::from_blob with nullptr
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix a bug for non-TN
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent ee4a17de
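
For context on the first fix listed above (at::from_blob with nullptr): in the single-output path the destination pointer can legitimately be null when no preallocated output tensor is supplied, and at::from_blob must not be handed a null pointer, so the fix falls back to at::empty in that case and otherwise views into the caller's buffer and steps the pointer by each chunk's own byte size. The following is a minimal, standalone sketch of that guard pattern using only the public libtorch API; the helper name make_group_output and the CPU/float32 setup are illustrative and are not the actual Transformer Engine code.

#include <torch/torch.h>

#include <cstdint>
#include <vector>

// Build one output chunk of a grouped GEMM that writes into a single
// contiguous buffer.  If no destination was supplied (data == nullptr),
// allocate a fresh tensor instead of wrapping a null pointer with
// at::from_blob; otherwise view into the caller's memory and advance the
// pointer by this chunk's byte size (rows * cols * element size).
at::Tensor make_group_output(void*& data, std::vector<int64_t> shape,
                             const torch::TensorOptions& opts) {
  at::Tensor chunk;
  if (data == nullptr) {
    chunk = at::empty(shape, opts);
  } else {
    chunk = at::from_blob(data, shape, opts);
    char* p = reinterpret_cast<char*>(data);
    p += shape[0] * shape[1] * chunk.element_size();
    data = reinterpret_cast<void*>(p);
  }
  return chunk;
}

int main() {
  auto opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);

  // One contiguous destination holding two 4x8 chunks back to back.
  at::Tensor dest = at::zeros({8, 8}, opts);
  void* ptr = dest.data_ptr();
  at::Tensor out0 = make_group_output(ptr, {4, 8}, opts);  // views rows 0..3
  at::Tensor out1 = make_group_output(ptr, {4, 8}, opts);  // views rows 4..7

  // The analogue of D being nullptr: the helper allocates instead of
  // wrapping an invalid pointer.
  void* null_dest = nullptr;
  at::Tensor owned = make_group_output(null_dest, {4, 8}, opts);
  return 0;
}

In the diff below, the same per-group decision is made inside the loop shown in the C++ hunk of te_general_grouped_gemm.
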
@@ -2131,21 +2131,30 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
     if layout == "TN":
         A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
-        B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
-        out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # output
+        B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
+        out = [torch.randn(m, n, dtype=dtype, device="cuda")]  # output
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
         grad = False
+        single_output = True
     elif layout == "NN":
         A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
-        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
-        out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # dgrad
+        B = list(
+            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
+        )  # grad_output
+        out = [torch.randn(m, k, dtype=dtype, device="cuda")]  # dgrad
+        out_ref = [o.clone() for o in torch.split(out[0], m_splits)]
         grad = True
+        single_output = True
     else:  # layout == "NT"
-        A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)  # input
-        B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)  # grad_output
+        A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits))  # input
+        B = list(
+            torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits)
+        )  # grad_output
         out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
+        out_ref = [o.clone() for o in out]
         grad = True
-    out_ref = [o.clone() for o in out]
+        single_output = False
     for i in range(z):
         general_gemm(
             A[i],
@@ -2157,17 +2166,20 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
             layout=layout,
             out=out_ref[i],
         )
+    if single_output:
+        out_ref = [torch.cat(out_ref)]
     general_grouped_gemm(
         A,
-        list(B),
-        list(out),
+        B,
+        out,
         dtype,
         get_multi_stream_cublas_workspace(),
-        m_splits=[k] * n,  # TODO, not sure
+        m_splits=m_splits,
         grad=grad,
         accumulate=accumulate,
         layout=layout,
+        single_output=single_output,
     )
     # should be bit-wise match
@@ -2190,7 +2202,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
         pytest.skip(reason_for_no_fp8)
     z, m, k, n = shape
-    m_splits = m // z
+    m_splits = [m // z] * z
     dtype = torch.bfloat16
     A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
@@ -2242,7 +2254,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate):
         out,
         dtype,
         get_multi_stream_cublas_workspace(),
-        m_splits=[k] * m_splits,
+        m_splits=m_splits,
         accumulate=accumulate,
     )
......
@@ -336,9 +336,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     auto dtype = GetATenDType(D_type);
     auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
     if (single_output) {
-      out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+      if (output_data_ptr == nullptr) {
+        out_tensor = at::empty(D_shape, opts);
+      } else {
+        out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
+      }
       char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
-      char_ptr += m_splits[i] * te_A.size(0) * (*D)[0].element_size();
+      char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
       output_data_ptr = reinterpret_cast<void*>(char_ptr);
       D_vectors.emplace_back(out_tensor);
     } else {
......
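
A note on the second fix in the commit message ("fix a bug for non-TN"), visible in the offset line of the hunk above: the test changes show that in the NN (dgrad) case each output chunk has shape (m_splits[i], k) while te_A is the (n, k) weight, so the old step of m_splits[i] * te_A.size(0) elements only equals the chunk size in the TN forward case, where the chunk is (m_splits[i], n). Stepping by D_shape[0] * D_shape[1] uses the chunk's own extent and is layout-independent. A tiny standalone check of that arithmetic follows; the sizes are made up purely for illustration.

#include <cstdint>
#include <cstdio>

int main() {
  // Made-up example sizes: per-group rows, weight rows (n), hidden dim (k).
  const std::int64_t m_split = 32, n = 64, k = 128;
  const std::int64_t elem = 2;  // bytes per element, e.g. bf16

  // TN forward: chunk is (m_split, n), so the old and new byte steps agree.
  std::printf("TN  old=%lld  new=%lld\n",
              static_cast<long long>(m_split * n * elem),
              static_cast<long long>(m_split * n * elem));

  // NN dgrad: chunk is (m_split, k); the old step of m_split * n lands in
  // the middle of the chunk whenever n != k.
  std::printf("NN  old=%lld  new=%lld\n",
              static_cast<long long>(m_split * n * elem),
              static_cast<long long>(m_split * k * elem));
  return 0;
}
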
@@ -269,9 +269,10 @@ class _GroupedLinear(torch.autograd.Function):
            general_grouped_gemm(
                weights,
                grad_output,
-                torch.split(dgrad, ctx.m_splits),
+                [dgrad],
                ctx.activation_dtype,
                get_multi_stream_cublas_workspace(),
+                single_output=True,
                layout="NN",
                m_splits=ctx.m_splits,
                grad=True,
......