Commit 6b987687 authored by Tim Moon, committed by GitHub

[PyTorch] Integration test for Megatron-LM (#1329)



* Handle deprecated `hidden_size` arg in norm modules
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Support initializing norm ops on CPU
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add integration test for Megatron-LM
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Rename Mcore integration test
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Handle case in RMSNorm where hidden dim is not provided
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b495120e
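The central API change: the norm modules' first argument is now `normalized_shape`, matching `torch.nn.LayerNorm`, and the old `hidden_size` argument is kept as a deprecated alias. A minimal sketch of the resulting behavior (sizes illustrative):

```python
import warnings

import transformer_engine.pytorch as te

# Preferred: positional `normalized_shape`, as in torch.nn.LayerNorm.
ln = te.LayerNorm(128)

# Deprecated: `hidden_size` still works, but now emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ln_legacy = te.LayerNorm(hidden_size=128)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Passing both arguments, or neither, raises a RuntimeError.
```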
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e

# Paths
: ${TE_PATH:=/opt/transformerengine}
: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}

# Download Megatron-LM if needed
if [ ! -d "${MCORE_PATH}" ]; then
    pushd "$(dirname "${MCORE_PATH}")"
    git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
    popd
fi
# Megatron-LM invocation
COMMAND="
NVTE_TORCH_COMPILE=0
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
NVTE_FLASH_ATTN=1
NVTE_FWD_LAYERNORM_SM_MARGIN=0
NVTE_BWD_LAYERNORM_SM_MARGIN=0
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_BIAS_GELU_NVFUSION=0
NVTE_BIAS_DROPOUT_FUSION=0
python
-m torch.distributed.launch
--use_env
--nnodes=1
--nproc_per_node=1
${MCORE_PATH}/pretrain_gpt.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--use-cpu-initialization
--num-layers 2
--hidden-size 128
--num-attention-heads 8
--seq-length 128
--max-position-embeddings 2048
--micro-batch-size 1
--global-batch-size 8
--train-iters 10
--eval-iters 10
--lr 1e-4
--mock-data
--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json
--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt
--transformer-impl transformer_engine
--fp8-format hybrid
"
COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')
# Launch Megatron-LM
bash -c "${COMMAND}"
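Note that the `: ${VAR:=default}` expansions at the top make both paths overridable from the environment, so the test can reuse an existing checkout, e.g. `TE_PATH=/workspace/TransformerEngine MCORE_PATH=/workspace/Megatron-LM bash <this script>` (invocation illustrative); otherwise the defaults apply and Megatron-LM `core_r0.9.0` is cloned on first run.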
@@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp):
     def __init__(
         self,
-        normalized_shape: Union[Iterable[int], int],
+        normalized_shape: Union[Iterable[int], int, None] = None,
         eps: float = 1e-5,
         sequence_parallel: Optional[bool] = None,  # legacy
         params_dtype: Optional[torch.dtype] = None,  # deprecated
         zero_centered_gamma: bool = False,
+        hidden_size: Optional[int] = None,  # deprecated
         **kwargs,
     ) -> None:
 
         # Handle deprecated options
+        if normalized_shape is None:
+            if hidden_size is None:
+                raise RuntimeError(
+                    "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
+                )
+            warnings.warn(
+                "`hidden_size` arg has been renamed to `normalized_shape` "
+                "for compatibility with `torch.nn.LayerNorm`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            normalized_shape = hidden_size
+        elif hidden_size is not None:
+            raise RuntimeError(
+                "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
+            )
         if params_dtype is not None:
             if "dtype" in kwargs:
                 raise RuntimeError(
@@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp):
     def __init__(
         self,
-        normalized_shape: Union[Iterable[int], int],
+        normalized_shape: Union[Iterable[int], int, None] = None,
         eps: float = 1e-5,
         sequence_parallel: Optional[bool] = None,  # legacy
         params_dtype: Optional[torch.dtype] = None,  # deprecated
         zero_centered_gamma: bool = False,
+        hidden_size: Optional[int] = None,  # deprecated
         **kwargs,
     ) -> None:
 
         # Handle deprecated options
+        if normalized_shape is None:
+            if hidden_size is None:
+                raise RuntimeError(
+                    "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
+                )
+            warnings.warn(
+                "`hidden_size` arg has been renamed to `normalized_shape` "
+                "for compatibility with `torch.nn.LayerNorm`.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            normalized_shape = hidden_size
+        elif hidden_size is not None:
+            raise RuntimeError(
+                "Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
+            )
         if params_dtype is not None:
             if "dtype" in kwargs:
                 raise RuntimeError(
@@ -20,7 +20,12 @@ from ...cpp_extensions import (
 )
 from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
 from ...tensor import Float8Tensor, QuantizedTensor
-from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
+from ...utils import (
+    canonicalize_device,
+    canonicalize_dtype,
+    clear_tensor_data,
+    devices_match,
+)
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, reshape
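The newly imported `devices_match` is used in `reset_parameters` below to decide whether a parameter buffer must be re-created on a different device. A rough sketch of the semantics such a helper needs (an illustration, not the actual TE implementation):

```python
import torch

def devices_match_sketch(device1: torch.device, device2: torch.device) -> bool:
    """Whether two torch.device objects refer to the same physical device.

    A bare "cuda" device is treated as the current CUDA device, so e.g.
    torch.device("cuda") matches torch.device("cuda:0") when device 0
    is current.
    """
    if device1.type != device2.type:
        return False
    if device1.type == "cuda":
        index1 = device1.index if device1.index is not None else torch.cuda.current_device()
        index2 = device2.index if device2.index is not None else torch.cuda.current_device()
        return index1 == index2
    return device1 == device2
```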
@@ -84,28 +89,23 @@ class LayerNorm(BasicOperation):
             normalized_shape = (normalized_shape,)
         else:
             normalized_shape = tuple(normalized_shape)
-        self._shape: tuple[int, ...] = normalized_shape
 
         # Parameter device
         defer_param_init = False
         device = canonicalize_device(device)
         if device.type == "meta":
             defer_param_init = True
-            device = canonicalize_device(None)
-        if device.type != "cuda":
-            raise ValueError(f"Only CUDA devices are supported (got {device})")
-        self.device: torch.device = device
 
         # Initialize parameters if needed
         dtype = canonicalize_dtype(dtype)
         weight = torch.empty(
-            self._shape,
-            device="meta",
+            normalized_shape,
+            device=device,
             dtype=dtype,
         )
         bias = torch.empty(
-            self._shape,
-            device="meta",
+            normalized_shape,
+            device=device,
             dtype=dtype,
         )
         weight = torch.nn.Parameter(weight)
@@ -143,17 +143,18 @@ class LayerNorm(BasicOperation):
     def reset_parameters(self) -> None:
         """Initialize parameter buffers and values"""
 
-        # Make sure parameter is initialized
+        # Parameter device
         weight = self.weight
         bias = self.bias
-        if weight.device.type != "cuda":
-            weight = torch.empty_like(weight, device=self.device)
-        else:
-            weight = weight.to(device=self.device)
-        if bias.device.type != "cuda":
-            bias = torch.empty_like(bias, device=self.device)
-        else:
-            bias = bias.to(device=self.device)
+        device = weight.device
+        if device.type == "meta":
+            device = canonicalize_device(None)
+
+        # Initialize param buffers
+        if not devices_match(weight.device, device):
+            weight = torch.empty_like(weight, device=device)
+        if not devices_match(bias.device, device):
+            bias = torch.empty_like(bias, device=device)
 
         # Initialize values
         if self.zero_centered_gamma:
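With this change, parameters left on the meta device at construction are materialized by `reset_parameters` on the default CUDA device, instead of the constructor eagerly pinning a device. A hedged usage sketch (the `transformer_engine.pytorch.ops` import path is assumed from context):

```python
import torch
from transformer_engine.pytorch.ops import LayerNorm

# Deferred initialization: meta-device parameters hold no storage.
ln = LayerNorm(128, device="meta")
assert ln.weight.device.type == "meta"

# reset_parameters() allocates real buffers; the meta device falls back
# to the default device via canonicalize_device(None).
ln.reset_parameters()
assert ln.weight.device.type == "cuda"
```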
@@ -184,17 +185,21 @@ class LayerNorm(BasicOperation):
     ) -> torch.Tensor:
 
         # Check tensor dims
+        weight = self.weight
+        weight_dims = tuple(weight.size())
         input_dims = tuple(input_.size())
-        if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
+        if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
             raise ValueError(
                 f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={self._shape}) are not compatible"
+                f"and weight tensor (shape={weight_dims}) are not compatible"
             )
 
         # Check input tensors
-        inner_dim = math.prod(self._shape)
-        device = self.device
-        dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
+        inner_dim = math.prod(weight_dims)
+        device = weight.device
+        if device.type != "cuda":
+            device = canonicalize_device(None)
+        dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
         x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
         b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
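The dimension check above requires the trailing input dims to match the weight dims exactly. A self-contained illustration of the same logic (shapes arbitrary):

```python
import torch

def check_dims(input_: torch.Tensor, weight_dims: tuple) -> None:
    """Mirror of the compatibility check in the forward pass."""
    input_dims = tuple(input_.size())
    if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims):] != weight_dims:
        raise ValueError(
            f"Input tensor (shape={input_dims}) "
            f"and weight tensor (shape={weight_dims}) are not compatible"
        )

check_dims(torch.empty(8, 32, 128), (128,))   # trailing dims match: OK
# check_dims(torch.empty(8, 32, 64), (128,))  # would raise ValueError
```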
@@ -266,6 +271,7 @@ class LayerNorm(BasicOperation):
 
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, means, rstdevs)
+            ctx.device = device
             ctx.dtype = dtype
             ctx.has_prev_op = prev_op is not None
@@ -282,9 +288,12 @@ class LayerNorm(BasicOperation):
         # Saved tensors from forward pass
         x, means, rstdevs = ctx.saved_tensors
 
+        # Tensor dims
+        weight_dims = self.weight.size()
+        inner_dim = math.prod(weight_dims)
+
         # Check input tensors
-        inner_dim = x.size(-1)
-        device = self.device
+        device = ctx.device
         dtype = ctx.dtype
         dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -312,6 +321,6 @@ class LayerNorm(BasicOperation):
 
         # Reshape results
         grad_input = reshape(dx, grad_output.size())
-        grad_weight = reshape(dw, self._shape)
-        grad_bias = reshape(db, self._shape)
+        grad_weight = reshape(dw, weight_dims)
+        grad_bias = reshape(db, weight_dims)
         return grad_input, (grad_weight, grad_bias)
@@ -20,7 +20,12 @@ from ...cpp_extensions import (
 )
 from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
 from ...tensor import Float8Tensor, QuantizedTensor
-from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
+from ...utils import (
+    canonicalize_device,
+    canonicalize_dtype,
+    clear_tensor_data,
+    devices_match,
+)
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_autocast_dtype, reshape
@@ -83,22 +88,17 @@ class RMSNorm(BasicOperation):
             normalized_shape = (normalized_shape,)
         else:
             normalized_shape = tuple(normalized_shape)
-        self._shape: tuple[int, ...] = normalized_shape
 
         # Parameter device
         defer_param_init = False
         device = canonicalize_device(device)
         if device.type == "meta":
             defer_param_init = True
-            device = canonicalize_device(None)
-        if device.type != "cuda":
-            raise ValueError(f"Only CUDA devices are supported (got {device})")
-        self.device: torch.device = device
 
         # Initialize parameters if needed
         weight = torch.empty(
-            self._shape,
-            device="meta",
+            normalized_shape,
+            device=device,
             dtype=canonicalize_dtype(dtype),
         )
         weight = torch.nn.Parameter(weight)
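Dropping the CUDA-only check is what enables the `--use-cpu-initialization` path exercised by the Megatron-LM script above: the op can now be constructed with CPU parameters, while the forward pass still computes on CUDA. A sketch, assuming the `transformer_engine.pytorch.ops.RMSNorm` import path:

```python
from transformer_engine.pytorch.ops import RMSNorm

# Previously this raised ValueError("Only CUDA devices are supported ...");
# after this change, CPU construction is allowed.
norm = RMSNorm(128, device="cpu")
assert norm.weight.device.type == "cpu"
```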
@@ -133,12 +133,15 @@ class RMSNorm(BasicOperation):
     def reset_parameters(self) -> None:
         """Initialize parameter buffers and values"""
 
-        # Make sure parameter is initialized
+        # Parameter device
         weight = self.weight
-        if weight.device.type != "cuda":
-            weight = torch.empty_like(weight, device=self.device)
-        else:
-            weight = weight.to(device=self.device)
+        device = weight.device
+        if device.type == "meta":
+            device = canonicalize_device(None)
+
+        # Initialize param buffers
+        if not devices_match(weight.device, device):
+            weight = torch.empty_like(weight, device=device)
 
         # Initialize values
         if self.zero_centered_gamma:
@@ -165,17 +168,21 @@ class RMSNorm(BasicOperation):
     ) -> torch.Tensor:
 
         # Check tensor dims
+        weight = self.weight
+        weight_dims = tuple(weight.size())
         input_dims = tuple(input_.size())
-        if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
+        if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
             raise ValueError(
                 f"Input tensor (shape={input_dims}) "
-                f"and weight tensor (shape={self._shape}) are not compatible"
+                f"and weight tensor (shape={weight_dims}) are not compatible"
             )
 
         # Check input tensors
-        inner_dim = math.prod(self._shape)
-        device = self.device
-        dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
+        inner_dim = math.prod(weight_dims)
+        device = weight.device
+        if device.type != "cuda":
+            device = canonicalize_device(None)
+        dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
         x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
         if isinstance(x, QuantizedTensor):
@@ -241,6 +248,7 @@ class RMSNorm(BasicOperation):
 
         # Save state for backward pass
         if requires_grad:
             ctx.save_for_backward(x, rstdevs)
+            ctx.device = device
             ctx.dtype = dtype
             ctx.has_prev_op = prev_op is not None
@@ -257,9 +265,12 @@ class RMSNorm(BasicOperation):
         # Saved tensors from forward pass
         x, rstdevs = ctx.saved_tensors
 
+        # Tensor dims
+        weight_dims = self.weight.size()
+        inner_dim = math.prod(weight_dims)
+
         # Check input tensors
-        inner_dim = x.size(-1)
-        device = self.device
+        device = ctx.device
         dtype = ctx.dtype
         dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
         w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -285,5 +296,5 @@ class RMSNorm(BasicOperation):
 
         # Reshape results
         grad_input = reshape(dx, grad_output.size())
-        grad_weight = reshape(dw, self._shape)
+        grad_weight = reshape(dw, weight_dims)
         return grad_input, (grad_weight,)
@@ -135,7 +135,11 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
             requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
             for idx in basic_op_idxs:
                 basic_op_ctxs[idx].requires_grad = requires_grad
-            x.requires_grad_(requires_grad=requires_grad)
+            if requires_grad != x.requires_grad:
+                if requires_grad:
+                    x.requires_grad_()
+                else:
+                    x = x.detach()
 
             # Forward op
             extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]
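This fuser change works around a PyTorch rule: `requires_grad` can only be switched off on leaf tensors, so a non-leaf activation must be detached instead. A small demonstration of the underlying behavior:

```python
import torch

leaf = torch.ones(4, requires_grad=True)
non_leaf = leaf * 2  # has a grad_fn, so it is not a leaf

# Switching gradients off in place is only legal on leaf tensors; on a
# non-leaf this raises RuntimeError ("you can only change requires_grad
# flags of leaf variables ...").
try:
    non_leaf.requires_grad_(False)
except RuntimeError as err:
    print(err)

# The workaround used above: detach() yields a gradient-free tensor.
x = non_leaf.detach()
assert not x.requires_grad
```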