"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "ac5cf86aa6aebbf9e42df51f7e377fbee85bc703"
Unverified commit 3c04c417 authored by Alp Dener, committed by GitHub

[PyTorch] Fix for deferred init bug causing NeMo MLPerf LLM crash (#619)



* added missing parameter materialization on real device for LayerNorm and RMSNorm
Signed-off-by: Alp Dener <adener@nvidia.com>

* added new unittest for deferred initialization and modified parameter materialization to support standalone execution outside of FSDP
Signed-off-by: Alp Dener <adener@nvidia.com>

* restored tensor parallel attributes that were being wiped out by the parameter reset
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed incorrect order of fp8 metadata initialization
Signed-off-by: Alp Dener <adener@nvidia.com>

* added deferred init unittest to the QA script
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent 178f1365
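For context, the deferred-initialization pattern that this fix targets looks roughly like the following. This is a minimal sketch based on the new unit test added below; it assumes a CUDA device and a fresh process with no prior allocations.

import torch
import transformer_engine.pytorch as te

# Construct the module on the 'meta' device: shapes and dtypes are recorded,
# but no CUDA memory is allocated yet.
layer = te.LayerNorm(1024, params_dtype=torch.bfloat16, device='meta')
assert torch.cuda.memory_allocated() == 0  # holds only if nothing else has touched the GPU

# Later (e.g. once the framework has decided how to shard), materialize the
# parameters on the real device and run their init functions.
with torch.no_grad():
    layer.reset_parameters()
assert all(p.device.type == 'cuda' for p in layer.parameters())

Before this commit, LayerNorm and RMSNorm were missing that materialization step (the first bullet above).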
@@ -8,6 +8,7 @@ set -e
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
+pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
...
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
_core_modules = [
te.LayerNorm,
te.RMSNorm,
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
]
_composed_modules = [
te.MultiheadAttention,
te.TransformerLayer,
]
batch_size = 32
seq_length = 2048
num_heads = 16
head_dim = 64
dtype = torch.bfloat16
class TestDeferredInit:
@staticmethod
def get_module_args(module):
hidden_size = num_heads * head_dim
args = (hidden_size,)
kwargs = {
'params_dtype': dtype,
'device': 'meta'
}
if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 2 * hidden_size
args += (ffn_hidden_size, )
kwargs['bias'] = True
if module == te.LayerNormMLP:
kwargs['seq_length'] = seq_length
elif module == te.MultiheadAttention:
args += (num_heads, )
kwargs['fuse_qkv_params'] = True
elif module == te.TransformerLayer:
args += (3 * hidden_size, num_heads)
kwargs['fuse_qkv_params'] = True
kwargs['seq_length'] = seq_length
return args, kwargs
@pytest.mark.parametrize("module_type", _core_modules+_composed_modules)
def test_zero_memory_init(
self,
module_type: torch.nn.Module,
) -> None:
"""Test deferred initialization via device='meta'."""
# This should not allocate any memory on CUDA device until we call reset_parameters() later.
args, kwargs = TestDeferredInit.get_module_args(module_type)
module = module_type(*args, **kwargs)
assert torch.cuda.memory_allocated(device=0) == 0.0, (
f"Initializing {module_type.__name__} with device='meta' prematurely allocated "
"memory on CUDA device"
)
del module
@pytest.mark.parametrize("module_type", _core_modules)
def test_reset_parameters(
self,
module_type: torch.nn.Module,
) -> None:
"""Test parameter reset for core modules that have been initialized with device='meta'."""
# Core modules own their own parameters so calling reset_parameters() here should
# materialize them on CUDA device.
args, kwargs = TestDeferredInit.get_module_args(module_type)
module = module_type(*args, **kwargs)
with torch.no_grad():
module.reset_parameters()
assert torch.cuda.memory_allocated(device=0) > 0.0, (
f"{module_type.__name__}.reset_parameters() failed to materialize parameters "
"on CUDA device"
)
del module
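The new test can also be run on its own, outside the QA script, with TE_PATH pointing at the repository root as above:

pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py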
@@ -769,7 +769,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        for name, param in self.named_parameters(recurse=False):
            # Ensure parameter is on a real device
            if param.device == torch.device('meta'):
-               param = param.to(device='cuda')
+               param = torch.empty_like(param, device='cuda')
            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
...
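As a standalone illustration of the one-line change above (a sketch, not part of the commit): a meta tensor has no underlying data, so moving it with .to(device='cuda') cannot copy anything; materialization instead allocates fresh storage with torch.empty_like and then lets the registered init_fn fill in real values.

import torch

param = torch.nn.Parameter(torch.empty(1024, device='meta'))

# param.to(device='cuda') would try to copy data that does not exist
# (PyTorch reports "Cannot copy out of meta tensor"), so instead:
materialized = torch.empty_like(param, device='cuda')  # fresh, uninitialized CUDA storage
torch.nn.init.zeros_(materialized)                     # stand-in for the module's init_fn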
@@ -138,8 +138,7 @@ class LayerNorm(torch.nn.Module):
                dtype=params_dtype,
            )
        )
-       setattr(self.weight, "sequence_parallel", sequence_parallel)
-       setattr(self.bias, "sequence_parallel", sequence_parallel)
+       self.sequence_parallel = sequence_parallel
        self.reset_parameters(defer_init=(device == 'meta'))

@@ -168,7 +167,15 @@ class LayerNorm(torch.nn.Module):
        """Init LayerNorm parameters"""
        if defer_init:
            return
+       if self.weight.device == torch.device('meta'):
+           self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device='cuda'))
+       setattr(self.weight, "sequence_parallel", self.sequence_parallel)
        init.constant_(self.weight, float(not self.zero_centered_gamma))
+       if self.bias.device == torch.device('meta'):
+           self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device='cuda'))
+       setattr(self.bias, "sequence_parallel", self.sequence_parallel)
        init.zeros_(self.bias)

    @no_torch_dynamo()
...
@@ -776,14 +776,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
            )
        self.register_parameter('layer_norm_weight', layer_norm_weight,
                                init_fn=init_method_constant(float(not self.zero_centered_gamma)))
-       setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)  # pylint: disable=access-member-before-definition
        if self.normalization != "RMSNorm":
            layer_norm_bias = torch.nn.Parameter(
                torch.empty(in_features, device=device, dtype=params_dtype)
            )
            self.register_parameter('layer_norm_bias', layer_norm_bias,
                                    init_fn=init_method_constant(0.0))
-           setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)  # pylint: disable=access-member-before-definition
        else:
            self.layer_norm_bias = None

@@ -876,22 +874,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
                bias = torch.nn.Parameter(bias)
                self.register_parameter(self.bias_names[i], bias,
                                        init_fn=init_method_constant(0.0))
-               if parallel_mode == "row":
-                   bias.sequence_parallel = sequence_parallel
            else:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, self.bias_names[i], bias)

-           # Configure tensor parallelism
-           set_tensor_model_parallel_attributes(
-               tensor=weight,
-               is_parallel=True,
-               dim=1 if parallel_mode == "row" else 0,
-               stride=1,
-           )
-           if parallel_mode == "column":
-               set_tensor_model_parallel_attributes(bias, True, 0, 1)

        # Concatenated tensors are not needed if not splitting
        # into multiple parameters
        if not is_subview:

@@ -935,6 +921,33 @@ class LayerNormLinear(TransformerEngineBaseModule):
        if self.layer_norm_bias is not None:
            init.zeros_(self.layer_norm_bias)

+   def reset_parameters(self, defer_init=False):
+       super().reset_parameters(defer_init=defer_init)
+       if not defer_init:
+           # Set parallelism attributes for layer norm parameters
+           setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
+           if self.normalization != "RMSNorm":
+               setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
+
+           # Set parallelism attributes for linear weights
+           for weight in self.weight_names:
+               set_tensor_model_parallel_attributes(
+                   tensor=getattr(self, weight),
+                   is_parallel=True,
+                   dim=1 if self.parallel_mode == "row" else 0,
+                   stride=1,
+               )
+
+           # Set parallelism attributes for linear biases
+           if self.use_bias:
+               for bias in self.bias_names:
+                   if self.parallel_mode == "row":
+                       setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel)
+                   elif self.parallel_mode == "column":
+                       set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
+
    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
...
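Why the tensor-parallel attributes had to move into reset_parameters() (a minimal sketch, not from the commit): attributes such as sequence_parallel are set on a specific tensor object, so recreating the parameter during materialization silently drops them; that is the wipe-out mentioned in the commit message.

import torch

weight = torch.nn.Parameter(torch.empty(8, device='meta'))
weight.sequence_parallel = True                     # lives on this tensor object only

# Materialization replaces the parameter with a freshly allocated tensor...
weight = torch.nn.Parameter(torch.empty_like(weight, device='cpu'))
print(hasattr(weight, 'sequence_parallel'))         # False: the attribute is gone

# ...so the attributes are now (re)applied after materialization, as in the
# reset_parameters() overrides added in this commit.
weight.sequence_parallel = True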
@@ -1208,14 +1208,12 @@ class LayerNormMLP(TransformerEngineBaseModule):
            )
        self.register_parameter('layer_norm_weight', layer_norm_weight,
                                init_fn=init_method_constant(float(not self.zero_centered_gamma)))
-       setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
        if self.normalization != "RMSNorm":
            layer_norm_bias = Parameter(
                torch.empty(hidden_size, device=device, dtype=params_dtype)
            )
            self.register_parameter('layer_norm_bias', layer_norm_bias,
                                    init_fn=init_method_constant(0.0))
-           setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)  # pylint: disable=access-member-before-definition
        else:
            self.layer_norm_bias = None

@@ -1234,7 +1232,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
                                init_fn=init_method,
                                get_rng_state_tracker=get_rng_state_tracker,
                                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
-       set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1)
        self.fp8_weight_shapes.append(self.fc1_weight.shape)

        if self.use_bias:

@@ -1243,7 +1240,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
            )
            self.register_parameter('fc1_bias', fc1_bias,
                                    init_fn=init_method_constant(0.0))
-           set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)  # pylint: disable=access-member-before-definition
        else:
            self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)

@@ -1255,7 +1251,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
                                init_fn=output_layer_init_method,
                                get_rng_state_tracker=get_rng_state_tracker,
                                fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT)
-       set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1)
        self.fp8_weight_shapes.append(self.fc2_weight.shape)

        if self.use_bias:

@@ -1264,9 +1259,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
            )
            self.register_parameter('fc2_bias', fc2_bias,
                                    init_fn=init_method_constant(0.0))
-           # RPL
-           if self.set_parallel_mode:
-               setattr(self.fc2_bias, "sequence_parallel", sequence_parallel)  # pylint: disable=access-member-before-definition
        else:
            self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device)

@@ -1312,6 +1304,23 @@ class LayerNormMLP(TransformerEngineBaseModule):
        if self.layer_norm_bias is not None:
            init.zeros_(self.layer_norm_bias)

+   def reset_parameters(self, defer_init=False):
+       super().reset_parameters(defer_init=defer_init)
+       if not defer_init:
+           # Set parallel attributes for layer norm parameters
+           setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
+           if self.normalization != "RMSNorm":
+               setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
+
+           # Set parallel attributes for linear parameters
+           set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1)
+           set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1)
+           if self.use_bias:
+               set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
+               if self.set_parallel_mode:
+                   setattr(self.fc2_bias, "sequence_parallel", self.sequence_parallel)
+
    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
...
@@ -767,22 +767,10 @@ class Linear(TransformerEngineBaseModule):
                bias = torch.nn.Parameter(bias)
                self.register_parameter(self.bias_names[i], bias,
                                        init_fn=init_method_constant(0.0))
-               if parallel_mode == "row":
-                   bias.sequence_parallel = sequence_parallel
            else:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, self.bias_names[i], bias)

-           # Configure tensor parallelism
-           set_tensor_model_parallel_attributes(
-               tensor=weight,
-               is_parallel=True,
-               dim=1 if parallel_mode == "row" else 0,
-               stride=1,
-           )
-           if parallel_mode == "column":
-               set_tensor_model_parallel_attributes(bias, True, 0, 1)

        # Concatenated tensors are not needed if not splitting
        # into multiple parameters
        if not is_subview:

@@ -804,6 +792,27 @@ class Linear(TransformerEngineBaseModule):
        else:
            self.gemm_bias_unfused_add = False

+   def reset_parameters(self, defer_init=False):
+       super().reset_parameters(defer_init=defer_init)
+       if not defer_init:
+           # Set parallelism attributes for linear weights
+           for weight in self.weight_names:
+               set_tensor_model_parallel_attributes(
+                   tensor=getattr(self, weight),
+                   is_parallel=True,
+                   dim=1 if self.parallel_mode == "row" else 0,
+                   stride=1,
+               )
+
+           # Set parallelism attributes for linear biases
+           if self.use_bias:
+               for bias in self.bias_names:
+                   if self.parallel_mode == "row":
+                       setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel)
+                   elif self.parallel_mode == "column":
+                       set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
+
    def get_fp8_weights_scratchpad(
        self,
        is_first_microbatch: Union[bool, None],
...
@@ -141,7 +141,7 @@ class RMSNorm(torch.nn.Module):
                dtype=params_dtype,
            )
        )
-       setattr(self.weight, "sequence_parallel", sequence_parallel)
+       self.sequence_parallel = sequence_parallel
        self.reset_parameters(defer_init=(device == 'meta'))

@@ -169,7 +169,11 @@ class RMSNorm(torch.nn.Module):
        """Reset RMSNorm parameters"""
        if defer_init:
            return
+       if self.weight.device == torch.device('meta'):
+           self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device='cuda'))
        init.constant_(self.weight, float(not self.zero_centered_gamma))
+       setattr(self.weight, "sequence_parallel", self.sequence_parallel)

    @no_torch_dynamo()
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
...