Unverified Commit f8449052 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

[PyTorch] Make grouped weights opt-in (#2678)



* Make grouped weights opt-in
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change varname
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 496620a9
...@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ ...@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4]) @pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear( def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split dtype,
bs,
model,
fp8_recipe,
fp8_model_params,
use_bias,
single_param,
num_gemms,
empty_split,
): ):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params: if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.") pytest.skip("FP8 model parameters are not supported in debug mode.")
...@@ -598,6 +607,9 @@ def test_sanity_grouped_linear( ...@@ -598,6 +607,9 @@ def test_sanity_grouped_linear(
bs = bs * 16 bs = bs * 16
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if single_param:
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
if fp8_recipe is not None: if fp8_recipe is not None:
if not is_fp8_supported(config): if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8") pytest.skip("Model config does not support FP8")
...@@ -617,7 +629,8 @@ def test_sanity_grouped_linear( ...@@ -617,7 +629,8 @@ def test_sanity_grouped_linear(
# Verify that weights are stored in contiguous GroupedTensor storage. # Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)] weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
check_grouped_tensor_pointers(weights, fp8_recipe) if single_param:
check_grouped_tensor_pointers(weights, fp8_recipe)
inp_hidden_states = torch.randn( inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
...@@ -636,6 +649,9 @@ def test_sanity_grouped_linear( ...@@ -636,6 +649,9 @@ def test_sanity_grouped_linear(
loss.backward() loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size) assert out.shape == (num_tokens, ffn_hidden_size)
if single_param:
del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
@pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from typing import Union, Optional, Callable, Tuple, List from typing import Union, Optional, Callable, Tuple, List
from itertools import chain from itertools import chain
import warnings import warnings
import os
import functools import functools
import torch import torch
...@@ -793,7 +794,9 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -793,7 +794,9 @@ class GroupedLinear(TransformerEngineBaseModule):
def reset_parameters(self, defer_init=False): def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init) super().reset_parameters(defer_init=defer_init)
self.make_grouped_weights(defer_init=defer_init) # Grouped tensor weights is an opt-in feature.
if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
self.make_grouped_weights(defer_init=defer_init)
def set_tensor_parallel_attributes(self, defer_init=False) -> None: def set_tensor_parallel_attributes(self, defer_init=False) -> None:
"""Set attributes needed for TP""" """Set attributes needed for TP"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment