Unverified commit f8449052, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

[PyTorch] Make grouped weights opt-in (#2678)



* Make grouped weights opt-in
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change varname
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 496620a9
......@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("use_bias", all_boolean)
@pytest.mark.parametrize("single_param", all_boolean)
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
@pytest.mark.parametrize("num_gemms", [4])
def test_sanity_grouped_linear(
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
dtype,
bs,
model,
fp8_recipe,
fp8_model_params,
use_bias,
single_param,
num_gemms,
empty_split,
):
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
pytest.skip("FP8 model parameters are not supported in debug mode.")
......@@ -598,6 +607,9 @@ def test_sanity_grouped_linear(
bs = bs * 16
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
if single_param:
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
......@@ -617,7 +629,8 @@ def test_sanity_grouped_linear(
# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
check_grouped_tensor_pointers(weights, fp8_recipe)
if single_param:
check_grouped_tensor_pointers(weights, fp8_recipe)
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
......@@ -636,6 +649,9 @@ def test_sanity_grouped_linear(
loss.backward()
assert out.shape == (num_tokens, ffn_hidden_size)
if single_param:
del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
......
......@@ -6,6 +6,7 @@
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import warnings
import os
import functools
import torch
......@@ -793,7 +794,9 @@ class GroupedLinear(TransformerEngineBaseModule):
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
self.make_grouped_weights(defer_init=defer_init)
# Grouped tensor weights is an opt-in feature.
if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
self.make_grouped_weights(defer_init=defer_init)
def set_tensor_parallel_attributes(self, defer_init=False) -> None:
"""Set attributes needed for TP"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.