Commit 2389ed3f authored by yuguo

Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7

parents 87e3e56e 58c3ac80
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/hipify_torch"]
path = 3rdparty/hipify_torch
url = https://github.com/ROCm/hipify_torch.git
......@@ -219,7 +219,9 @@ def train_and_evaluate(args):
else:
fp8_recipe = None
-with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+with te.fp8_autocast(
+    enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+):
encoder = Net(num_embed)
# We use nn.Embed, thus inputs need to be in int
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......
......@@ -193,7 +193,9 @@ def train_and_evaluate(args):
else:
fp8_recipe = None
-with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+with te.fp8_autocast(
+    enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
......
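Both example scripts above now thread an explicit `MeshResource` into `te.fp8_autocast`. A minimal sketch of the updated call on a single device (model body elided; `fp8_recipe=None` falls back to the default `DelayedScaling` recipe, per the `fp8_autocast` hunk further below):

```python
# Minimal sketch of the new fp8_autocast signature in use.
import transformer_engine.jax as te

with te.fp8_autocast(
    enabled=True,  # turn FP8 on
    fp8_recipe=None,  # None selects the default DelayedScaling recipe
    mesh_resource=te.sharding.MeshResource(),  # empty resource is fine on a single device
):
    ...  # initialize and run the model inside the autocast context
```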
......@@ -173,7 +173,7 @@ class TestDistributedLayernormMLP:
)
# Single GPU
-with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
single_jitter = jax.jit(
value_and_grad_func,
static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
......@@ -330,7 +330,7 @@ class TestDistributedLayernormMLP:
with use_jax_gemm(enabled=with_jax_gemm):
# Single GPUs
-with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
intermediate_dim=INTERMEDIATE,
......
......@@ -28,6 +28,7 @@ from transformer_engine.jax.quantize import (
is_fp8_available,
update_collections,
)
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard
@pytest.fixture(autouse=True, scope="function")
......@@ -490,11 +491,17 @@ class BaseTester:
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
QuantizeConfig.finalize() # Ensure FP8 disabled.
+with global_shard_guard(
+    MeshResource()
+): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_forward(data_shape, dtype)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
QuantizeConfig.finalize() # Ensure FP8 disabled.
+with global_shard_guard(
+    MeshResource()
+): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_backward(data_shape, dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
......@@ -502,6 +509,9 @@ class BaseTester:
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test forward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
+with global_shard_guard(
+    MeshResource()
+): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
......@@ -510,6 +520,9 @@ class BaseTester:
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
"""Test backward with fp8 enabled"""
QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
+with global_shard_guard(
+    MeshResource()
+): # Empty MeshResource is used as we are running on a single device
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
QuantizeConfig.finalize()
......
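The guarded tests follow a common bracketing pattern: initialize the global quantize state for a recipe, run the check under an empty-mesh guard, then finalize to restore the FP8-disabled default. A sketch of that pattern; the `QuantizeConfig` import path is assumed from this test module:

```python
# Sketch of the test bracketing pattern; import paths are assumptions
# taken from the surrounding test module, not verified API.
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import QuantizeConfig
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

QuantizeConfig.initialize(fp8_recipe=recipe.DelayedScaling())  # enable FP8 globally
with global_shard_guard(MeshResource()):  # empty mesh: single device
    ...  # run the forward/backward under test
QuantizeConfig.finalize()  # restore the FP8-disabled default
```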
......@@ -274,6 +274,8 @@ model_configs_mla = {
"mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
"mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
}
......
......@@ -252,8 +252,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) &&
// 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-(!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
-   sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
+(!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 ||
+    cudnn_runtime_version == 91300) &&
+   is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
!(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
......
......@@ -532,22 +532,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
&epilogue, sizeof(epilogue)));
if (counter != nullptr) {
-#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
-NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
NVTE_ERROR(
-    "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
+    "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
CUBLAS_VERSION < 130000
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-           "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
+           "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
-           "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
+           "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
......@@ -783,20 +783,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
#ifndef __HIP_PLATFORM_AMD__
// Check CUDA and cuBLAS versions
-#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
-NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
CUDA_VERSION);
#endif
#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
-NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
+NVTE_ERROR(
+    "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#endif
-NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-           "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
+NVTE_CHECK(
+    cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
+    "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(
    cublas_version() >= 120205 && cublas_version() < 130000,
-    "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
+    "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
#endif
......
......@@ -18,7 +18,6 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
......@@ -185,12 +184,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
}
}
-// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-// store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-cudaTriggerProgrammaticLaunchCompletion();
-#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c;
......@@ -426,12 +419,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
}
}
-// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-// store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-cudaTriggerProgrammaticLaunchCompletion();
-#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases
// for full thread tile, it's the same thing here
......@@ -939,15 +926,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
#else
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
-dim3 grid(num_blocks_x, num_blocks_y, 1);
-cudaLaunchAttribute attribute[1];
-attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-attribute[0].val.programmaticStreamSerializationAllowed = 1;
-cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
-if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
-  cfg.attrs = attribute;
-  cfg.numAttrs = 1;
-}
#endif
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
......@@ -962,6 +940,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
+dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
#endif
......@@ -972,21 +951,19 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
tensor_map_output_trans =
get_tensor_map<OutputType>(output_t, num_rows, row_length);
}
cudaLaunchKernelEx(&cfg,
block_scaled_cast_transpose_kernel<kReturnTranspose, float,
InputType, OutputType>,
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-scale_stride_x, scale_stride_y, scale_t_stride_x,
-scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
+scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
+tensor_map_output_trans, pow_2_scale);
} else {
-cudaLaunchKernelEx(
-    &cfg,
-    block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
-                                                  InputType, OutputType>,
+block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
+                                              OutputType>
+    <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
......
......@@ -24,7 +24,6 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh"
namespace transformer_engine {
......@@ -252,14 +251,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__syncthreads();
-// If not return columnwise, we trigger the next kernel here so that it's load from global memory
-// can overlap with this kernel's return rowwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
-  cudaTriggerProgrammaticLaunchCompletion();
-}
-#endif
// Step 2: Cast and store to output_c
if (return_rowwise) {
constexpr int r_stride =
......@@ -365,14 +356,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
-// If return columnwise, we trigger the next kernel here so that it's load from global memory
-// can overlap with this kernel's return columnwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-if (return_columnwise_gemm_ready || return_columnwise_compact) {
-  cudaTriggerProgrammaticLaunchCompletion();
-}
-#endif
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
......@@ -1448,12 +1431,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
#else
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
-dim3 grid(num_blocks_x, num_blocks_y, 1);
-cudaLaunchAttribute attribute[1];
-attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-attribute[0].val.programmaticStreamSerializationAllowed = 1;
#endif
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
......@@ -1463,6 +1441,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
+dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
#endif
TRANSFORMER_ENGINE_SWITCH_CONDITION(
......@@ -1532,30 +1511,21 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
}
#else
size_t smem_bytes = kSMemSize * sizeof(InputType);
-cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
-if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
-    90) {
-  cfg.attrs = attribute;
-  cfg.numAttrs = 1;
-}
// shared memory must be requested up
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
} cudaLaunchKernelEx(&cfg,
block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
OutputType>,
} block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
-reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-scale_stride_x, scale_stride_y, scale_t_stride_x,
-scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
-pow2_scale);
+reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
+scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
+columnwise_option, pow2_scale);
#endif
) // kAligned
) // OutputType
......
......@@ -205,11 +205,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
-// Trigger the next kernel, so its TMA load can be overlapped with the current kernel
-if (stage == STAGES - 1) {
-  cudaTriggerProgrammaticLaunchCompletion();
-}
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
......@@ -1139,13 +1134,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
-cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
-// This kernel will only be called on sm100+, so no need to check sm_arch
-cudaLaunchAttribute attribute[1];
-attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-attribute[0].val.programmaticStreamSerializationAllowed = 1;
-cfg.attrs = attribute;
-cfg.numAttrs = 1;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
......@@ -1153,13 +1141,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-cudaLaunchKernelEx(
-    &cfg,
-    cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
-                         false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
+cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
+                     false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+    <<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
+scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
......@@ -1167,13 +1155,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-cudaLaunchKernelEx(
-    &cfg,
-    cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
-                         true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
+cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
+                     true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+    <<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
+scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
......@@ -1181,13 +1169,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-cudaLaunchKernelEx(
-    &cfg,
-    cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
-                         true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
+cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true,
+                     CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+    <<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
+scale_stride_colwise);
break;
}
......
......@@ -8,6 +8,7 @@ import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce
+import warnings
import jax
import jax.numpy as jnp
......@@ -34,6 +35,7 @@ from ..quantize import (
is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv,
)
+from ..sharding import global_mesh_resource
from .misc import get_padded_spec
......@@ -490,7 +492,8 @@ class GemmPrimitive(BasePrimitive):
# Non-contracting dims of RHS always needs to be gathered along the FSDP axis
rhs_non_cspecs = tuple(
None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec
for spec in rhs_non_cspecs
)
# Non-contracting dims of LHS to be gathered along the SP axis.
......@@ -656,6 +659,12 @@ class GemmPrimitive(BasePrimitive):
prefix = "GemmPrimitive_"
+warnings.warn(
+    "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
+    " please turn off Shardy by exporting the environment variable"
+    " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
+)
def _generate_operand_rules(name, ndim, cdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims)
......
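Until the Shardy interaction is resolved, the warning above suggests disabling the Shardy partitioner. One way to do that from Python, before JAX is imported (exporting the variable in the shell has the same effect):

```python
# Disable Shardy propagation before importing JAX, as the warning suggests.
import os

os.environ["JAX_USE_SHARDY_PARTITIONER"] = "0"

import jax  # noqa: E402  # must come after the environment variable is set
```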
......@@ -26,6 +26,7 @@ from .module import LayerNorm, Softmax
from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
from ..attention import fused_attn
+from ..attention import CPStrategy
from ..softmax import SoftmaxType
from ..sharding import num_of_devices
from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
......@@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
+context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
context_checkpoint_name: str = "context"
@nn.compact
......@@ -323,6 +325,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
+context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_kvpacked():
......@@ -350,6 +353,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
+context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
elif self.qkv_layout.is_separate():
......@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
+context_parallel_strategy=self.context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)
else:
......@@ -505,6 +510,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
context_parallel_causal_load_balanced (bool):
Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
context_parallel_axis (str): The name of the context parallel axis.
+context_parallel_strategy (str): The context parallel strategy: "DEFAULT", "ALL_GATHER" (or "ALLGATHER"), or "RING" (case-insensitive).
context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.
Optimization parameters
......@@ -529,6 +535,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq: Optional[int] = 1
context_parallel_causal_load_balanced: bool = False
context_parallel_axis: str = ""
+context_parallel_strategy: str = "DEFAULT"
context_checkpoint_name: str = "context"
@nn.compact
......@@ -648,6 +655,24 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
scale_factor = self.scale_factor
del self.scale_factor
+# case-insensitive mapping for context parallel strategy
+cp_strategy_map = {
+    "DEFAULT": CPStrategy.DEFAULT,
+    "ALL_GATHER": CPStrategy.ALL_GATHER,
+    "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
+    "RING": CPStrategy.RING,
+}
+strategy_key = self.context_parallel_strategy.upper()
+if strategy_key in cp_strategy_map:
+    context_parallel_strategy = cp_strategy_map[strategy_key]
+else:
+    valid_strategies = list(cp_strategy_map.keys())
+    raise ValueError(
+        f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
+        f"Valid options are: {valid_strategies} (case insensitive)"
+    )
if not use_fused_attn:
# unfused attention only supports splitted query, key, value
if qkv_layout.is_qkvpacked():
......@@ -696,6 +721,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
max_segments_per_seq=self.max_segments_per_seq,
context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
context_parallel_axis=self.context_parallel_axis,
+context_parallel_strategy=context_parallel_strategy,
context_checkpoint_name=self.context_checkpoint_name,
)(
query,
......
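With the mapping above, the public module accepts the strategy as a case-insensitive string. A hypothetical instantiation (other constructor arguments elided; the exact `DotProductAttention` parameters shown here are assumptions):

```python
# Hypothetical usage of the new string-valued knob; "ring" is upper-cased
# and mapped to CPStrategy.RING internally.
from transformer_engine.jax.flax import DotProductAttention

attn = DotProductAttention(
    head_dim=64,
    num_attention_heads=16,
    context_parallel_axis="cp",
    context_parallel_strategy="ring",  # or "default", "all_gather"/"allgather"
)
```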
......@@ -404,9 +404,6 @@ def fp8_autocast(
if fp8_recipe is None:
fp8_recipe = recipe.DelayedScaling()
-if mesh_resource is None:
-    mesh_resource = MeshResource()
Config = DelayedScalingQuantizeConfig
if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
Config = BlockScalingQuantizeConfig
......
......@@ -286,7 +286,7 @@ class MeshResource:
cp_resource: str = None
-_GLOBAL_MESH_RESOURCE = MeshResource()
+_GLOBAL_MESH_RESOURCE = None
@contextmanager
......@@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource:
Returns:
The current MeshResource instance
"""
+assert _GLOBAL_MESH_RESOURCE is not None, (
+    "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
+    " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
+    " wrapping your program in 'with global_shard_guard(MeshResource()):'"
+)
return _GLOBAL_MESH_RESOURCE
......
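Because `_GLOBAL_MESH_RESOURCE` now defaults to `None`, any code path that reaches `global_mesh_resource()` must run inside a `global_shard_guard`. A sketch for both the single-device and a sharded setup (the mesh axis names "dp" and "tp" are illustrative):

```python
# Sketch: satisfy the new assertion by always providing a MeshResource.
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

# Single device: an empty resource is enough.
with global_shard_guard(MeshResource()):
    ...  # run the program

# Multi-device: name the mesh axes used for data and tensor parallelism.
with global_shard_guard(MeshResource(dp_resource="dp", tp_resource="tp")):
    ...  # run the sharded program
```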
......@@ -4,6 +4,8 @@
"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
+import contextlib
+import gc
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
import torch
......@@ -58,6 +60,25 @@ def graph_pool_handle():
return _graph_pool_handle()
+@contextlib.contextmanager
+def _graph_context_wrapper(*args, **kwargs):
+    """Wrapper around `torch.cuda.graph`.
+
+    This wrapper is a temporary workaround for a PyTorch bug:
+    automatic garbage collection can destroy a graph while another
+    graph is being captured, resulting in a CUDA error. See
+    https://github.com/pytorch/pytorch/pull/161037.
+
+    """
+    gc_is_enabled = gc.isenabled()
+    if gc_is_enabled:
+        gc.disable()
+    with torch.cuda.graph(*args, **kwargs):
+        yield
+    if gc_is_enabled:
+        gc.enable()
def _make_graphed_callables(
callables: SingleOrTuple[Callable],
sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
......@@ -445,7 +466,7 @@ def _make_graphed_callables(
args = sample_args[per_callable_fwd_idx]
kwargs = sample_kwargs[per_callable_fwd_idx]
fwd_graph = fwd_graphs[per_callable_fwd_idx]
-with torch.cuda.graph(fwd_graph, pool=mempool):
+with _graph_context_wrapper(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
flatten_outputs, spec = _tree_flatten(outputs)
per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
......@@ -483,7 +504,7 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
-with torch.cuda.graph(bwd_graph, pool=mempool):
+with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......@@ -548,7 +569,7 @@ def _make_graphed_callables(
per_callable_output_unflatten_spec = []
graph_id = 0
for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
-with torch.cuda.graph(fwd_graph, pool=mempool):
+with _graph_context_wrapper(fwd_graph, pool=mempool):
outputs = func(*args, **kwargs)
graph_callables[graph_id] = func
graph_id += 1
......@@ -570,7 +591,7 @@ def _make_graphed_callables(
torch.empty_like(o) if o.requires_grad else None for o in static_outputs
)
if is_training:
-with torch.cuda.graph(bwd_graph, pool=mempool):
+with _graph_context_wrapper(bwd_graph, pool=mempool):
grad_inputs = torch.autograd.grad(
outputs=tuple(o for o in static_outputs if o.requires_grad),
inputs=tuple(i for i in static_input_surface if i.requires_grad),
......
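The wrapper replaces direct `torch.cuda.graph` capture so that Python garbage collection cannot free one graph while another is mid-capture. The same precaution, sketched outside Transformer Engine:

```python
# Sketch of the precaution the wrapper automates: pause Python garbage
# collection around CUDA graph capture.
import gc

import torch

graph = torch.cuda.CUDAGraph()
gc_was_enabled = gc.isenabled()
if gc_was_enabled:
    gc.disable()  # keep GC from destroying other graphs during capture
try:
    with torch.cuda.graph(graph):
        ...  # capture the workload
finally:
    if gc_was_enabled:
        gc.enable()
```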
......@@ -12,7 +12,6 @@ from typing import Any, Optional
import torch
-from transformer_engine.pytorch.module.base import get_workspace
from ...cpp_extensions import general_gemm
from ...distributed import (
CudaRNGStatesTracker,
......@@ -20,18 +19,24 @@ from ...distributed import (
reduce_scatter_along_first_dim,
)
from ...fp8 import FP8GlobalStateManager, Recipe
-from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
+from ...module.base import (
+    _2X_ACC_FPROP,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+    get_dummy_wgrad,
+    get_workspace,
+)
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ..op import BasicOperation, OperationContext
-from .._common import maybe_dequantize, is_quantized_tensor
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize, is_quantized_tensor
def _wait_async(handle: Optional[Any]) -> None:
......@@ -73,7 +78,8 @@ class BasicLinear(BasicOperation):
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
-meaningful.
+meaningful. This is primarily intended to integrate with
+Megatron-LM.
userbuffers_options, dict, optional
Options for overlapping tensor-parallel communication with
compute using Userbuffers. This feature is highly
......@@ -979,20 +985,22 @@ class BasicLinear(BasicOperation):
# Saved tensors from forward pass
(x_local, w) = ctx.saved_tensors
-# wgrad fusion
+# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = self._accumulate_into_main_grad
grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad:
-if hasattr(self.weight, "__fsdp_param__"):
-    self.weight.main_grad = self.weight.get_main_grad()
-if not hasattr(self.weight, "main_grad"):
+weight_param = self.weight
+if hasattr(weight_param, "__fsdp_param__"):
+    weight_param.main_grad = weight_param.get_main_grad()
+if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
-grad_weight = self.weight.main_grad.detach()
+grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -1019,6 +1027,17 @@ class BasicLinear(BasicOperation):
# Clear input tensor if possible
clear_tensor_data(x_local)
+# Megatron-LM wgrad fusion
+# Note: Return dummy tensor for grad weight if needed.
+if accumulate_into_main_grad:
+    grad_weight = None
+    weight_param = self.weight
+    if hasattr(weight_param, "grad_added_to_main_grad"):
+        weight_param.grad_added_to_main_grad = True
+        grad_weight = get_dummy_wgrad(
+            list(weight_param.size()),
+            weight_param.dtype,
+            zero=getattr(weight_param, "zero_out_wgrad", False),
+        )
return grad_input, [grad_weight]
......@@ -9,13 +9,10 @@ from typing import Optional
import torch
-from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
+from ...utils import clear_tensor_data
+from ..basic import BasicLinear, MakeExtraOutput
+from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearAdd(FusedOperation):
......@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
-# wgrad fusion
+# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-if hasattr(linear_op.weight, "__fsdp_param__"):
-    linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-if not hasattr(linear_op.weight, "main_grad"):
+weight_param = linear_op.weight
+if hasattr(weight_param, "__fsdp_param__"):
+    weight_param.main_grad = weight_param.get_main_grad()
+if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
-grad_weight = linear_op.weight.main_grad.detach()
+grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
)
-if accumulate_into_main_grad:
-    grad_weight = None
# Clear input tensor if possible
clear_tensor_data(x_local)
+# Megatron-LM wgrad fusion
+# Note: Return dummy tensor for grad weight if needed.
+if accumulate_into_main_grad:
+    grad_weight = None
+    weight_param = linear_op.weight
+    if hasattr(weight_param, "grad_added_to_main_grad"):
+        weight_param.grad_added_to_main_grad = True
+        grad_weight = get_dummy_wgrad(
+            list(weight_param.size()),
+            weight_param.dtype,
+            zero=getattr(weight_param, "zero_out_wgrad", False),
+        )
return grad_input, [(grad_weight,), ()], [(), ()]
......
......@@ -9,13 +9,10 @@ from typing import Optional
import torch
-from ..basic import BasicLinear, ConstantScale
-from ..op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
+from ...utils import clear_tensor_data
+from ..basic import BasicLinear, ConstantScale
+from ..op import FusedOperation, FusibleOperation, OperationContext
class BackwardLinearScale(FusedOperation):
......@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
# Saved tensors from forward pass
(x_local, w) = linear_op_ctx.saved_tensors
-# wgrad fusion
+# Megatron-LM wgrad fusion
# Note: Get grad tensor from param so we can accumulate
# directly into it.
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
grad_weight = None
if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-if hasattr(linear_op.weight, "__fsdp_param__"):
-    linear_op.weight.main_grad = linear_op.weight.get_main_grad()
-if not hasattr(linear_op.weight, "main_grad"):
+weight_param = linear_op.weight
+if hasattr(weight_param, "__fsdp_param__"):
+    weight_param.main_grad = weight_param.get_main_grad()
+if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
-grad_weight = linear_op.weight.main_grad.detach()
+grad_weight = weight_param.main_grad.detach()
else:
accumulate_into_main_grad = False
......@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
)
-if accumulate_into_main_grad:
-    grad_weight = None
# Clear input tensor if possible
clear_tensor_data(x_local)
+# Megatron-LM wgrad fusion
+# Note: Return dummy tensor for grad weight if needed.
+if accumulate_into_main_grad:
+    grad_weight = None
+    weight_param = linear_op.weight
+    if hasattr(weight_param, "grad_added_to_main_grad"):
+        weight_param.grad_added_to_main_grad = True
+        grad_weight = get_dummy_wgrad(
+            list(weight_param.size()),
+            weight_param.dtype,
+            zero=getattr(weight_param, "zero_out_wgrad", False),
+        )
return grad_input, [(), (grad_weight,)], [(), ()]
......
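All three backward paths now follow the Megatron-LM wgrad-accumulation contract: gradients land in `weight.main_grad`, `grad_added_to_main_grad` is flipped to `True`, and autograd receives only a dummy wgrad. A hedged sketch of what a caller sets up; the `BasicLinear` constructor arguments and shapes below are assumptions, not verified API:

```python
# Hedged sketch of the Megatron-LM contract these ops serve; constructor
# arguments and shapes are assumptions.
import torch
from transformer_engine.pytorch.ops import BasicLinear

op = BasicLinear(128, 128, accumulate_into_main_grad=True)
weight = op.weight

# The framework pre-allocates the accumulation buffer and the flag.
weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)
weight.grad_added_to_main_grad = False

out = op(torch.randn(16, 128, device=weight.device, dtype=weight.dtype))
out.sum().backward()

# Gradients were accumulated into main_grad; autograd only saw a dummy
# wgrad, so the optimizer should read weight.main_grad, not weight.grad.
assert weight.grad_added_to_main_grad
```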