Commit 970620a5 authored by wenjh

merge nv_release_v2.10 to release_v2.10


Signed-off-by: wenjh <wenjh@sugon.com>
parents c1a1c04e 769ed778
@@ -24,7 +24,7 @@ class AllReduce(BasicOperation):
     Parameters
     ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
         Process group for communication
     """
...
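The docstring hunks throughout this merge apply the numpydoc convention of separating a parameter's name and type with " : " rather than ":". Schematically (an illustrative docstring, not taken from the diff):

    def all_reduce_example(x, process_group=None):
        """Reduce a tensor across a process group.

        Parameters
        ----------
        process_group : torch.distributed.ProcessGroup, default = world group
            Process group for communication
        """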
@@ -25,7 +25,6 @@ from ...module.base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
     get_dummy_wgrad,
-    get_workspace,
 )
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
@@ -54,27 +53,27 @@ class BasicLinear(BasicOperation):
     Parameters
     ----------
-    in_features: int
+    in_features : int
         Inner dimension of input tensor
-    out_features: int
+    out_features : int
         Inner dimension of output tensor
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
         Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
         Whether to apply sequence parallelism together with tensor
         parallelism, i.e. distributing input or output tensors along
         outer dimension (sequence or batch dim) when not distributing
         along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
         Function that returns `CudaRNGStatesTracker`, which is used
         for model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
         Whether to directly accumulate weight gradients into the
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
@@ -138,8 +137,10 @@ class BasicLinear(BasicOperation):
             out_features=out_features,
         )
 
-        # Whether weight tensor is natively quantized
+        # Initialize recipe state if needed for natively quantized weight
         self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
+        if self._with_quantized_weight:
+            self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())
 
         # Initialize parameters if needed
         weight = torch.empty(
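With this change, constructing the operation while FP8 parameters are enabled initializes the quantization recipe state eagerly in `__init__` instead of deferring it. A usage sketch, assuming Transformer Engine's `fp8_model_init`/`fp8_autocast` contexts and the `BasicLinear` operations API touched by this diff:

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.ops import BasicLinear

    # Under fp8_model_init, the weight is created natively quantized, so the
    # new branch above calls reset_recipe_state during construction.
    with te.fp8_model_init():
        op = BasicLinear(1024, 1024)

    x = torch.randn(16, 1024, device="cuda")
    with te.fp8_autocast():
        y = op(x)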
@@ -585,7 +586,6 @@ class BasicLinear(BasicOperation):
         y, *_ = general_gemm(
             w,
             x,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=output_quantizer,
             alpha=alpha,
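Every `general_gemm` call site in this merge drops the explicit `get_workspace()` argument, which suggests the GEMM wrapper now obtains its cuBLAS workspace internally rather than requiring each caller to pass one. A hypothetical sketch of such an internal cache (the names `_WORKSPACE` and `_get_cached_workspace` are illustrative, not from the source):

    import torch

    _WORKSPACE: torch.Tensor | None = None

    def _get_cached_workspace(num_bytes: int = 32 * 1024 * 1024) -> torch.Tensor:
        """Return a cached CUDA workspace buffer, growing it if needed."""
        global _WORKSPACE
        if _WORKSPACE is None or _WORKSPACE.numel() < num_bytes:
            _WORKSPACE = torch.empty(num_bytes, dtype=torch.uint8, device="cuda")
        return _WORKSPACE

Centralizing the buffer this way removes a repetitive argument from every call site and keeps workspace sizing decisions in one place.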
@@ -875,7 +875,6 @@ class BasicLinear(BasicOperation):
         dx, *_ = general_gemm(
             w,
             dy,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=grad_input_quantizer,
             alpha=grad_input_alpha,
@@ -928,7 +927,6 @@ class BasicLinear(BasicOperation):
         dw, *_ = general_gemm(
             x,
             dy,
-            get_workspace(),
             out_dtype=dw_dtype,
             alpha=grad_weight_alpha,
             beta=grad_weight_beta,
...
@@ -22,16 +22,16 @@ class Bias(BasicOperation):
     Parameters
     ----------
-    size: int
+    size : int
         Inner dimension of input tensor
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel: bool, default = `False`
+    tensor_parallel : bool, default = `False`
         Whether to distribute input tensor and bias tensors along
         inner dimension
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
     """
...
@@ -10,7 +10,7 @@ import os
 import torch
 
-from ... import torch_version
+from ...torch_version import torch_version
 from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...jit import (
     l2normalization_fused,
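This import fix matters because the two forms bind different objects: the old form binds the `torch_version` submodule itself, while the new form binds the name of the same spelling defined inside that module, which is what call sites actually expect. Schematically (comment sketch, not runnable on its own):

    # Old: binds the submodule object; using it as a value/callable would fail
    from ... import torch_version

    # New: binds the attribute named torch_version inside that submodule
    from ...torch_version import torch_version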
@@ -40,11 +40,11 @@ class L2Normalization(BasicOperation):
     ----------
     eps : float, default = 1e-6
         A value added to the denominator for numerical stability
-    seq_length: int, default = None
+    seq_length : int, default = None
         sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
         functions are warmed up before training to ensure same kernels are used for forward
         propagation and activation recompute phase.
-    micro_batch_size: int, default = None
+    micro_batch_size : int, default = None
         batch size per training step. Needed for JIT Warmup, a technique where jit
         fused functions are warmed up before training to ensure same kernels are
         used for forward propagation and activation recompute phase.
...
@@ -42,14 +42,14 @@ class LayerNorm(BasicOperation):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator of layer normalization for
         numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
     zero_centered_gamma : bool, default = 'False'
         If `True`, the :math:`\gamma` parameter is initialized to zero
@@ -58,7 +58,7 @@ class LayerNorm(BasicOperation):
     .. math::
         y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma) + \beta
-    sm_margin: int or dict, default = 0
+    sm_margin : int or dict, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM
...
@@ -23,9 +23,9 @@ class Quantize(BasicOperation):
     Parameters
     ----------
-    forward: bool, default = `True`
+    forward : bool, default = `True`
         Perform quantization in forward pass
-    backward: bool, default = `False`
+    backward : bool, default = `False`
         Perform quantization in backward pass
     """
...
@@ -23,7 +23,7 @@ class ReduceScatter(BasicOperation):
     Parameters
     ----------
-    process_group: torch.distributed.ProcessGroup, default = world group
+    process_group : torch.distributed.ProcessGroup, default = world group
         Process group for communication
     """
...
@@ -24,7 +24,7 @@ class Reshape(BasicOperation):
     Parameters
     ----------
-    shape: iterable of int
+    shape : iterable of int
         Output tensor dimensions. If one dimension is -1, it is
         inferred based on input tensor dimensions.
...
@@ -42,13 +42,13 @@ class RMSNorm(BasicOperation):
     Parameters
     ----------
-    normalized_shape: int or iterable of int
+    normalized_shape : int or iterable of int
         Inner dimensions of input tensor
     eps : float, default = 1e-5
         A value added to the denominator for numerical stability
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
     zero_centered_gamma : bool, default = 'False'
         If `True`, the :math:`\gamma` parameter is initialized to zero
@@ -57,7 +57,7 @@ class RMSNorm(BasicOperation):
     .. math::
         y = \frac{x}{\sqrt{\mathrm{Var}[x] + \varepsilon}} * (1 + \gamma)
-    sm_margin: int, default = 0
+    sm_margin : int, default = 0
         Number of SMs to exclude when launching CUDA kernels. This
         helps overlap with other kernels, e.g. communication kernels.
         For more fine-grained control, provide a dict with the SM
@@ -249,4 +249,6 @@ class RMSNorm(BasicOperation):
     ) -> torch.Tensor:
         """Every operand in this function has a defined ONNX translation."""
         weight = self.weight + 1 if self.zero_centered_gamma else self.weight
-        return torch.nn.functional.rms_norm(input_, input_.shape[-1:], weight, self.eps)
+        variance = input_.pow(2).mean(-1, keepdim=True)
+        normalized = input_ * torch.rsqrt(variance + self.eps)
+        return normalized * weight
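The new ONNX export path expands the single `torch.nn.functional.rms_norm` call into pow/mean/rsqrt primitives that each have a defined ONNX translation. A quick equivalence check of the decomposition (a sketch; eps, shapes, and tolerances are arbitrary):

    import torch

    def rms_norm_decomposed(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Same decomposition as the export path above
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    x = torch.randn(4, 1024)
    w = torch.randn(1024)
    ref = torch.nn.functional.rms_norm(x, x.shape[-1:], w, 1e-5)
    torch.testing.assert_close(rms_norm_decomposed(x, w, 1e-5), ref)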
@@ -90,15 +90,15 @@ def fuse_backward_activation_bias(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
-    recipe: Recipe, optional
+    recipe : Recipe, optional
         Used quantization recipe
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
...
@@ -87,13 +87,13 @@ def fuse_backward_add_rmsnorm(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
...
@@ -119,13 +119,13 @@ def fuse_backward_linear_scale(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
...
@@ -142,13 +142,13 @@ def fuse_forward_linear_bias_activation(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
...
@@ -139,13 +139,13 @@ def fuse_forward_linear_bias_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
...
@@ -118,13 +118,13 @@ def fuse_forward_linear_scale_add(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
...
@@ -19,7 +19,6 @@ from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_dummy_wgrad,
     get_ub,
-    get_workspace,
 )
 from ...quantized_tensor import Quantizer
 from ...tensor.mxfp8_tensor import MXFP8Quantizer
@@ -378,7 +377,6 @@ class UserbuffersBackwardLinear(FusedOperation):
         dx, *_ = general_gemm(
             w,
             dy,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=grad_input_quantizer,
             layout="NN",
@@ -464,7 +462,6 @@ class UserbuffersBackwardLinear(FusedOperation):
         dw, *_ = general_gemm(
             x,
             dy,
-            get_workspace(),
             out_dtype=dw_dtype,
             accumulate=accumulate_into_grad_weight,
             layout="NT",
@@ -592,13 +589,13 @@ def fuse_userbuffers_backward_linear(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Backward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated backward pass operations
     """
...
@@ -18,7 +18,6 @@ from ...quantization import FP8GlobalStateManager
 from ...module.base import (
     fill_userbuffers_buffer_for_all_gather,
     get_ub,
-    get_workspace,
     _2X_ACC_FPROP,
 )
 from ...quantized_tensor import Quantizer
@@ -243,7 +242,6 @@ class UserbuffersForwardLinear(FusedOperation):
         gemm_output, *_, reduce_scatter_output = general_gemm(
             w,
             x,
-            get_workspace(),
             out_dtype=dtype,
             quantization_params=output_quantizer,
             bias=bias,
@@ -379,13 +377,13 @@ def fuse_userbuffers_forward_linear(
     Parameters
     ----------
-    ops: list of tuples
+    ops : list of tuples
         Forward pass operations and the indices of the corresponding
         basic operations.
 
     Returns
     -------
-    ops: list of tuples
+    ops : list of tuples
         Updated forward pass operations
     """
...
@@ -310,7 +310,7 @@ class OperationFuser:
     Parameters
     ----------
-    ops: list of FusibleOperation
+    ops : list of FusibleOperation
         Pipeline of operations
     """
...
@@ -27,29 +27,29 @@ class Linear(FusedOperation):
     Parameters
     ----------
-    in_features: int
+    in_features : int
         Inner dimension of input tensor
-    out_features: int
+    out_features : int
         Inner dimension of output tensor
-    bias: bool, default = `True`
+    bias : bool, default = `True`
         Apply additive bias
-    device: torch.device, default = default CUDA device
+    device : torch.device, default = default CUDA device
         Tensor device
-    dtype: torch.dtype, default = default dtype
+    dtype : torch.dtype, default = default dtype
         Tensor datatype
-    tensor_parallel_mode: {`None`, "column", "row"}, default = `None`
+    tensor_parallel_mode : {`None`, "column", "row"}, default = `None`
         Mode for tensor parallelism
-    tensor_parallel_group: torch.distributed.ProcessGroup, default = world group
+    tensor_parallel_group : torch.distributed.ProcessGroup, default = world group
         Process group for tensor parallelism
-    sequence_parallel: bool, default = `False`
+    sequence_parallel : bool, default = `False`
         Whether to apply sequence parallelism together with tensor
         parallelism, i.e. distributing input or output tensors along
         outer dimension (sequence or batch dim) when not distributing
         along inner dimension (embedding dim)
-    rng_state_tracker_function: callable
+    rng_state_tracker_function : callable
         Function that returns CudaRNGStatesTracker, which is used for
         model-parallel weight initialization
-    accumulate_into_main_grad: bool, default = `False`
+    accumulate_into_main_grad : bool, default = `False`
         Whether to directly accumulate weight gradients into the
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
...