Unverified Commit 97100139 authored by Paweł Gadziński, committed by GitHub

[PyTorch] Fix issue when last input in GroupedLinear is empty.
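
For context, a minimal repro sketch of the failure mode being fixed (sizes are hypothetical and chosen only for illustration; the call signature matches the `test_sanity_grouped_linear` test added below):

```python
import torch
from transformer_engine.pytorch import GroupedLinear

# Hypothetical sizes, for illustration only.
num_gemms, hidden_size, ffn_hidden_size = 4, 64, 256
layer = GroupedLinear(
    num_gemms, hidden_size, ffn_hidden_size, params_dtype=torch.float32
).cuda()

# The last split is empty: all tokens belong to the first three GEMMs.
m_splits = [8, 8, 8, 0]
inp = torch.randn(sum(m_splits), hidden_size, device="cuda", requires_grad=True)

# Before this fix, constructing the output view for the trailing empty split
# could fail, because the running output pointer already sits one past the
# end of the allocated output buffer (see the C++ change below).
out = layer(inp, m_splits)
out.sum().backward()
```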



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* more sensitive tests
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* typo fix and fix for skipping the test on Blackwell
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent e1c4f51e
@@ -25,6 +25,7 @@ from transformer_engine.pytorch.utils import (
from transformer_engine.pytorch import (
    LayerNormLinear,
    Linear,
    GroupedLinear,
    LayerNormMLP,
    TransformerLayer,
    RMSNorm,
@@ -532,6 +533,55 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
    assert out.shape == (num_tokens, ffn_hidden_size)

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes_with_zero)
@pytest.mark.parametrize("model", ["small", "weird"])
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
):
    config = model_configs[model]
    ffn_hidden_size = 4 * config.hidden_size
    # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
    bs = bs * 16
    num_tokens = bs * config.seq_len * (num_gemms - 1)
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if fp8_recipe.mxfp8():
            pytest.skip("Grouped linear does not support MXFP8")
        if not config.is_fp8_supported():
            pytest.skip("Model config does not support FP8")
    use_fp8 = fp8_recipe is not None
    with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
        te_grouped_linear = GroupedLinear(
            num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
        ).cuda()
    inp_hidden_states = torch.randn(
        num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    m_splits = [bs * config.seq_len] * num_gemms
    if empty_split == "first":
        m_splits[0] = 0
    elif empty_split == "last":
        m_splits[-1] = 0
    elif empty_split == "middle":
        m_splits[num_gemms // 2] = 0
    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
        out = te_grouped_linear(inp_hidden_states, m_splits)
        loss = out.sum()
        loss.backward()
    assert out.shape == (num_tokens, ffn_hidden_size)

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("model", ["small", "weird"])
...
@@ -329,9 +329,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    at::Tensor out_tensor;
    auto size_t_shape =
        pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
    bool D_numel_is_zero = false;
    std::vector<int64_t> D_shape;
    for (size_t t : size_t_shape) {
      D_shape.push_back(t);
      if (t == 0) {
        D_numel_is_zero = true;
      }
    }
    auto dtype = GetATenDType(D_type);
    auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
@@ -339,8 +343,13 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    if (output_data_ptr == nullptr) {
      out_tensor = at::empty(D_shape, opts);
    } else {
      // We need to check !D_numel_is_zero because if the final input portion has zero elements,
      // output_data_ptr would point beyond the allocated memory of D. This would cause
      // at::from_blob to fail as it would reference memory not allocated by CUDA.
      if (!D_numel_is_zero) {
        out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
      }
    }
    char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
    char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
    output_data_ptr = reinterpret_cast<void*>(char_ptr);
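
To make the comment above concrete, here is a small illustrative sketch (plain Python, not part of the change) of the running-offset arithmetic performed across splits: after the last non-empty split, the byte offset equals the total size of D, so a view created for a trailing zero-sized split would start one past the end of the allocation.

```python
def split_offsets(m_splits, n_cols, elem_size):
    """Byte offset at which each per-split output view would start within D."""
    offsets, off = [], 0
    for m in m_splits:
        offsets.append(off)
        off += m * n_cols * elem_size  # mirrors the char_ptr advancement above
    return offsets, off  # final `off` equals the total bytes allocated for D

# With an empty last split, its view would start exactly at the end of D,
# which is why at::from_blob must be skipped when the split has zero elements.
offsets, total_bytes = split_offsets([8, 8, 8, 0], n_cols=256, elem_size=2)
assert offsets[-1] == total_bytes
```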
...