"docs/results/vscode:/vscode.git/clone" did not exist on "4b0547d963b8f44bc0dbcda19e346ccceeaf1148"
Commit 2389ed3f authored by yuguo

Merge branch 'release_v2.7' of https://github.com/NVIDIA/TransformerEngine into release_v2.7

parents 87e3e56e 58c3ac80
[submodule "3rdparty/googletest"] [submodule "3rdparty/googletest"]
path = 3rdparty/googletest path = 3rdparty/googletest
url = https://github.com/google/googletest.git url = https://github.com/google/googletest.git
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/hipify_torch"] [submodule "3rdparty/hipify_torch"]
path = 3rdparty/hipify_torch path = 3rdparty/hipify_torch
url = https://github.com/ROCm/hipify_torch.git url = https://github.com/ROCm/hipify_torch.git
@@ -219,7 +219,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None
-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         encoder = Net(num_embed)
         # We use nn.Embed, thus inputs need to be in int
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
......
@@ -193,7 +193,9 @@ def train_and_evaluate(args):
     else:
         fp8_recipe = None
-    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
+    with te.fp8_autocast(
+        enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource()
+    ):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
         tx = optax.sgd(args.lr, args.momentum)
......
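For reference, a minimal single-device sketch of the updated te.fp8_autocast call used in these examples; the layer, shapes, and recipe choice below are illustrative assumptions rather than part of the examples themselves:

import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

# An empty MeshResource marks single-device execution; on a device mesh the
# mesh axis names would be passed to MeshResource instead. enabled=True
# requires FP8-capable hardware at run time.
with te.fp8_autocast(
    enabled=True,
    fp8_recipe=recipe.DelayedScaling(),
    mesh_resource=te.sharding.MeshResource(),
):
    layer = te_flax.DenseGeneral(features=128)
    params = layer.init(jax.random.PRNGKey(0), jnp.ones((8, 128), jnp.bfloat16))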
@@ -173,7 +173,7 @@ class TestDistributedLayernormMLP:
         )
         # Single GPU
-        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             single_jitter = jax.jit(
                 value_and_grad_func,
                 static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
@@ -330,7 +330,7 @@ class TestDistributedLayernormMLP:
         with use_jax_gemm(enabled=with_jax_gemm):
             # Single GPUs
-            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+            with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
                 ln_mlp_single = LayerNormMLP(
                     layernorm_type=layernorm_type,
                     intermediate_dim=INTERMEDIATE,
......
@@ -28,6 +28,7 @@ from transformer_engine.jax.quantize import (
     is_fp8_available,
     update_collections,
 )
+from transformer_engine.jax.sharding import MeshResource, global_shard_guard


 @pytest.fixture(autouse=True, scope="function")
@@ -490,19 +491,28 @@ class BaseTester:
     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
         QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        self.runner(attrs).test_forward(data_shape, dtype)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_forward(data_shape, dtype)

     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
         QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        self.runner(attrs).test_backward(data_shape, dtype)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_backward(data_shape, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
         QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
         QuantizeConfig.finalize()

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@@ -510,7 +520,10 @@ class BaseTester:
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
         QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
+        with global_shard_guard(
+            MeshResource()
+        ):  # Empty MeshResource is used as we are running on a single device
+            self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
         QuantizeConfig.finalize()
......
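The tests above rely on the new requirement that a global mesh resource is set before TE/JAX modules run; a minimal sketch of the single-device pattern (the layer and shapes are illustrative assumptions):

import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

# An empty MeshResource satisfies global_mesh_resource() on a single device.
with global_shard_guard(MeshResource()):
    layer = te_flax.LayerNorm()
    x = jnp.ones((4, 256), jnp.float32)
    params = layer.init(jax.random.PRNGKey(0), x)
    out = layer.apply(params, x)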
@@ -274,6 +274,8 @@ model_configs_mla = {
     "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
     "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
     "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
 }
......
@@ -252,8 +252,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
          (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
           cudnn_runtime_version >= 91100)) &&
         // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-        (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
-           sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
+        (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 ||
+            cudnn_runtime_version == 91300) &&
+           is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
            !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
         // bias type
         ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
......
@@ -532,22 +532,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                                              &epilogue, sizeof(epilogue)));
   if (counter != nullptr) {
-#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
-    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+    NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
                CUDA_VERSION);
 #endif
 #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
     NVTE_ERROR(
-        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
+        "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
         CUBLAS_VERSION);
 #endif
 #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \
     CUBLAS_VERSION < 130000
     NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
+               "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
                cuda::cudart_version());
     NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
-               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
+               "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
                cublas_version());
     if (m_split == 0) m_split = 1;
     if (n_split == 0) n_split = 1;
@@ -783,20 +783,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
 #ifndef __HIP_PLATFORM_AMD__
   // Check CUDA and cuBLAS versions
-#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000)
-  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ",
+#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000)
+  NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ",
              CUDA_VERSION);
 #endif
 #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000)
-  NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ",
-             CUBLAS_VERSION);
+  NVTE_ERROR(
+      "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
+      CUBLAS_VERSION);
 #endif
-  NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
-             "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ",
-             cuda::cudart_version());
+  NVTE_CHECK(
+      cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
+      "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
+      cuda::cudart_version());
   NVTE_CHECK(
       cublas_version() >= 120205 && cublas_version() < 130000,
-      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ",
+      "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
       cublas_version());
 #endif
......
@@ -18,7 +18,6 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
-#include "common/util/cuda_runtime.h"
 #include "common/util/ptx.cuh"
 #include "common/utils.cuh"
@@ -185,12 +184,6 @@
     }
   }

-  // Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-  // store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  cudaTriggerProgrammaticLaunchCompletion();
-#endif
-
   // Step 3: Store cast output, Step 4: do transpose within thread tile
   OVecCast tmp_output_c;
@@ -426,12 +419,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
     }
   }

-  // Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-  // store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  cudaTriggerProgrammaticLaunchCompletion();
-#endif
-
   // Step 3: Store cast output, Step 4: do transpose within thread tile
   // Edge case: in the non-full tile case, there are three subcases
   // for full thread tile, it's the same thing here
@@ -939,15 +926,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 #else
   const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
   const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
-  dim3 grid(num_blocks_x, num_blocks_y, 1);
-  cudaLaunchAttribute attribute[1];
-  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attribute[0].val.programmaticStreamSerializationAllowed = 1;
-  cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
-  if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
-    cfg.attrs = attribute;
-    cfg.numAttrs = 1;
-  }
 #endif
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
@@ -962,6 +940,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
         dim3 grid(num_blocks_x, num_blocks_y, 1);
         const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
 #else
+        dim3 grid(num_blocks_x, num_blocks_y, 1);
         const bool full_tile =
             row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
 #endif
@@ -972,28 +951,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
             tensor_map_output_trans =
                 get_tensor_map<OutputType>(output_t, num_rows, row_length);
           }
-          cudaLaunchKernelEx(&cfg,
-                             block_scaled_cast_transpose_kernel<kReturnTranspose, float,
-                                                                InputType, OutputType>,
-                             reinterpret_cast<const InputType*>(input.dptr),
-                             reinterpret_cast<OutputType*>(output.dptr),
-                             reinterpret_cast<OutputType*>(output_t.dptr),
-                             reinterpret_cast<float*>(scale_inv.dptr),
-                             reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-                             scale_stride_x, scale_stride_y, scale_t_stride_x,
-                             scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
+          block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
+              <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
+                  reinterpret_cast<const InputType*>(input.dptr),
+                  reinterpret_cast<OutputType*>(output.dptr),
+                  reinterpret_cast<OutputType*>(output_t.dptr),
+                  reinterpret_cast<float*>(scale_inv.dptr),
+                  reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
+                  scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
+                  tensor_map_output_trans, pow_2_scale);
         } else {
-          cudaLaunchKernelEx(
-              &cfg,
-              block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
-                                                            InputType, OutputType>,
-              reinterpret_cast<const InputType*>(input.dptr),
-              reinterpret_cast<OutputType*>(output.dptr),
-              reinterpret_cast<OutputType*>(output_t.dptr),
-              reinterpret_cast<float*>(scale_inv.dptr),
-              reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-              scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
-              pow_2_scale);
+          block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
+                                                        OutputType>
+              <<<grid, THREADS_PER_BLOCK, 0, stream>>>(
+                  reinterpret_cast<const InputType*>(input.dptr),
+                  reinterpret_cast<OutputType*>(output.dptr),
+                  reinterpret_cast<OutputType*>(output_t.dptr),
+                  reinterpret_cast<float*>(scale_inv.dptr),
+                  reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
+                  scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
+                  pow_2_scale);
 #else
       while (true) {
         if (128 == block_len) {
......
@@ -24,7 +24,6 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/transpose/cast_transpose.h"
-#include "common/util/cuda_runtime.h"
 #include "common/utils.cuh"

 namespace transformer_engine {
@@ -252,14 +251,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
   __syncthreads();

-  // If not return columnwise, we trigger the next kernel here so that it's load from global memory
-  // can overlap with this kernel's return rowwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
-    cudaTriggerProgrammaticLaunchCompletion();
-  }
-#endif
-
   // Step 2: Cast and store to output_c
   if (return_rowwise) {
     constexpr int r_stride =
@@ -365,14 +356,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
     }
   }

-  // If return columnwise, we trigger the next kernel here so that it's load from global memory
-  // can overlap with this kernel's return columnwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-  if (return_columnwise_gemm_ready || return_columnwise_compact) {
-    cudaTriggerProgrammaticLaunchCompletion();
-  }
-#endif
-
   // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
   if (return_columnwise_gemm_ready) {
     constexpr int c_stride =
@@ -1448,12 +1431,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 #else
   const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
   const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
-  dim3 grid(num_blocks_x, num_blocks_y, 1);
-  cudaLaunchAttribute attribute[1];
-  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attribute[0].val.programmaticStreamSerializationAllowed = 1;
 #endif
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.dtype, InputType,
@@ -1463,6 +1441,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
         dim3 grid(num_blocks_x, num_blocks_y, 1);
         const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
 #else
+        dim3 grid(num_blocks_x, num_blocks_y, 1);
         const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
 #endif
         TRANSFORMER_ENGINE_SWITCH_CONDITION(
@@ -1532,34 +1511,25 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
           }
 #else
           size_t smem_bytes = kSMemSize * sizeof(InputType);
-          cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
-          if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
-              90) {
-            cfg.attrs = attribute;
-            cfg.numAttrs = 1;
-          }
           // shared memory must be requested up
           if (smem_bytes >= 48 * 1024) {
             cudaError_t err = cudaFuncSetAttribute(
                 &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
                 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
             NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
           }
-          cudaLaunchKernelEx(&cfg,
-                             block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
-                                                                   OutputType>,
-                             reinterpret_cast<const InputType*>(input.dptr),
-                             reinterpret_cast<OutputType*>(output.dptr),
-                             reinterpret_cast<OutputType*>(output_t.dptr),
-                             reinterpret_cast<float*>(scale_inv.dptr),
-                             reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-                             scale_stride_x, scale_stride_y, scale_t_stride_x,
-                             scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
-                             pow2_scale);
+          block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
+              <<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
+                  reinterpret_cast<const InputType*>(input.dptr),
+                  reinterpret_cast<OutputType*>(output.dptr),
+                  reinterpret_cast<OutputType*>(output_t.dptr),
+                  reinterpret_cast<float*>(scale_inv.dptr),
+                  reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
+                  scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
+                  columnwise_option, pow2_scale);
 #endif
       )  // kAligned
   )  // OutputType
   )  // InputType
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
......
@@ -205,11 +205,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // Wait for the data to have arrived
       ptx::mbarrier_wait_parity(&mbar[stage], parity);

-      // Trigger the next kernel, so its TMA load can be overlapped with the current kernel
-      if (stage == STAGES - 1) {
-        cudaTriggerProgrammaticLaunchCompletion();
-      }
-
       float thread_amax = 0.0f;
       if constexpr (COLWISE_SCALING) {
         const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
@@ -1139,13 +1134,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
   const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

-  cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
-  // This kernel will only be called on sm100+, so no need to check sm_arch
-  cudaLaunchAttribute attribute[1];
-  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-  attribute[0].val.programmaticStreamSerializationAllowed = 1;
-  cfg.attrs = attribute;
-  cfg.numAttrs = 1;
-
   switch (scaling_type) {
     case ScalingType::ROWWISE:
       cudaFuncSetAttribute(
@@ -1153,13 +1141,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
                                false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-      cudaLaunchKernelEx(
-          &cfg,
-          cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
-                               false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
-          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-          tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-          workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+      cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
+                           false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+          <<<grid, block_size, dshmem_size, stream>>>(
+              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+              tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
+              workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
+              scale_stride_colwise);
       break;
     case ScalingType::COLWISE:
       cudaFuncSetAttribute(
@@ -1167,13 +1155,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
                                true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-      cudaLaunchKernelEx(
-          &cfg,
-          cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
-                               true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
-          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-          tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-          workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+      cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
+                           true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+          <<<grid, block_size, dshmem_size, stream>>>(
+              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+              tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
+              workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
+              scale_stride_colwise);
       break;
     case ScalingType::BIDIMENSIONAL:
       cudaFuncSetAttribute(
@@ -1181,13 +1169,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
                                true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
           cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
-      cudaLaunchKernelEx(
-          &cfg,
-          cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
-                               true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
-          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-          tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
-          workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
+      cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true,
+                           CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
+          <<<grid, block_size, dshmem_size, stream>>>(
+              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+              tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
+              workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
+              scale_stride_colwise);
       break;
   }
......
@@ -8,6 +8,7 @@ import operator
 from collections.abc import Iterable
 from typing import Tuple, Sequence, Union
 from functools import partial, reduce
+import warnings

 import jax
 import jax.numpy as jnp
@@ -34,6 +35,7 @@ from ..quantize import (
     is_fp8_gemm_with_all_layouts_supported,
     apply_padding_to_scale_inv,
 )
+from ..sharding import global_mesh_resource
 from .misc import get_padded_spec
@@ -490,7 +492,8 @@ class GemmPrimitive(BasePrimitive):
         # Non-contracting dims of RHS always needs to be gathered along the FSDP axis
         rhs_non_cspecs = tuple(
-            None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs
+            None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec
+            for spec in rhs_non_cspecs
         )

         # Non-contracting dims of LHS to be gathered along the SP axis.
@@ -656,6 +659,12 @@ class GemmPrimitive(BasePrimitive):
         prefix = "GemmPrimitive_"

+        warnings.warn(
+            "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
+            " please turn off Shardy by exporting the environment variable"
+            " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
+        )
+
         def _generate_operand_rules(name, ndim, cdims):
             specs = []
             ldims = tuple(i for i in range(ndim) if i not in cdims)
......
@@ -26,6 +26,7 @@ from .module import LayerNorm, Softmax
 from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor
 from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type
 from ..attention import fused_attn
+from ..attention import CPStrategy
 from ..softmax import SoftmaxType
 from ..sharding import num_of_devices
 from ..sharding import get_sharding_map_logic_axis_to_mesh_axis
@@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
+    context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT
     context_checkpoint_name: str = "context"

     @nn.compact
@@ -323,6 +325,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_parallel_strategy=self.context_parallel_strategy,
                 context_checkpoint_name=self.context_checkpoint_name,
             )
         elif self.qkv_layout.is_kvpacked():
@@ -350,6 +353,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_parallel_strategy=self.context_parallel_strategy,
                 context_checkpoint_name=self.context_checkpoint_name,
             )
         elif self.qkv_layout.is_separate():
@@ -372,6 +376,7 @@ class _FusedDotProductAttention(nn.Module):  # pylint: disable=too-few-public-me
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_parallel_strategy=self.context_parallel_strategy,
                 context_checkpoint_name=self.context_checkpoint_name,
             )
         else:
@@ -505,6 +510,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     context_parallel_causal_load_balanced (bool):
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
     context_parallel_axis (str): The name of the context parallel axis.
+    context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING.
     context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention.

     Optimization parameters
@@ -529,6 +535,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
     max_segments_per_seq: Optional[int] = 1
     context_parallel_causal_load_balanced: bool = False
     context_parallel_axis: str = ""
+    context_parallel_strategy: str = "DEFAULT"
     context_checkpoint_name: str = "context"

     @nn.compact
@@ -648,6 +655,24 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
         scale_factor = self.scale_factor
         del self.scale_factor

+        # case-insensitive mapping for context parallel strategy
+        cp_strategy_map = {
+            "DEFAULT": CPStrategy.DEFAULT,
+            "ALL_GATHER": CPStrategy.ALL_GATHER,
+            "ALLGATHER": CPStrategy.ALL_GATHER,  # Alternative spelling
+            "RING": CPStrategy.RING,
+        }
+        strategy_key = self.context_parallel_strategy.upper()
+        if strategy_key in cp_strategy_map:
+            context_parallel_strategy = cp_strategy_map[strategy_key]
+        else:
+            valid_strategies = list(cp_strategy_map.keys())
+            raise ValueError(
+                f"Invalid context parallel strategy: {self.context_parallel_strategy}. "
+                f"Valid options are: {valid_strategies} (case insensitive)"
+            )
+
         if not use_fused_attn:
             # unfused attention only supports splitted query, key, value
             if qkv_layout.is_qkvpacked():
@@ -696,6 +721,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                 max_segments_per_seq=self.max_segments_per_seq,
                 context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced,
                 context_parallel_axis=self.context_parallel_axis,
+                context_parallel_strategy=context_parallel_strategy,
                 context_checkpoint_name=self.context_checkpoint_name,
             )(
                 query,
......
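Since DotProductAttention.context_parallel_strategy is now a case-insensitive string, a user-facing value resolves to the CPStrategy enum as in this sketch (the helper function is illustrative, not part of the module; the mapping mirrors the hunk above):

from transformer_engine.jax.attention import CPStrategy

def resolve_cp_strategy(name: str) -> CPStrategy:
    # Mirrors the lookup added above: "default", "all_gather" / "allgather", "ring".
    mapping = {
        "DEFAULT": CPStrategy.DEFAULT,
        "ALL_GATHER": CPStrategy.ALL_GATHER,
        "ALLGATHER": CPStrategy.ALL_GATHER,
        "RING": CPStrategy.RING,
    }
    try:
        return mapping[name.upper()]
    except KeyError as exc:
        raise ValueError(f"Invalid context parallel strategy: {name!r}") from exc

# e.g. DotProductAttention(..., context_parallel_strategy="ring") selects CPStrategy.RING.
assert resolve_cp_strategy("ring") is CPStrategy.RING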
@@ -404,9 +404,6 @@ def fp8_autocast(
     if fp8_recipe is None:
         fp8_recipe = recipe.DelayedScaling()

-    if mesh_resource is None:
-        mesh_resource = MeshResource()
-
     Config = DelayedScalingQuantizeConfig
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         Config = BlockScalingQuantizeConfig
......
@@ -286,7 +286,7 @@ class MeshResource:
     cp_resource: str = None


-_GLOBAL_MESH_RESOURCE = MeshResource()
+_GLOBAL_MESH_RESOURCE = None


 @contextmanager
@@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource:
     Returns:
         The current MeshResource instance
     """
+    assert _GLOBAL_MESH_RESOURCE is not None, (
+        "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard"
+        " context. If you are not using multiple GPUs, you can use an empty MeshResource by"
+        " wrapping your program in 'with global_shard_guard(MeshResource()):'"
+    )
     return _GLOBAL_MESH_RESOURCE
......
@@ -4,6 +4,8 @@
 """Functions for CUDA Graphs support in FP8"""
 from collections.abc import Iterable
+import contextlib
+import gc
 from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

 import torch
@@ -58,6 +60,25 @@ def graph_pool_handle():
     return _graph_pool_handle()


+@contextlib.contextmanager
+def _graph_context_wrapper(*args, **kwargs):
+    """Wrapper around `torch.cuda.graph`.
+
+    This wrapper is a temporary workaround for a PyTorch bug:
+    automatic garbage collection can destroy a graph while another
+    graph is being captured, resulting in a CUDA error. See
+    https://github.com/pytorch/pytorch/pull/161037.
+
+    """
+    gc_is_enabled = gc.isenabled()
+    if gc_is_enabled:
+        gc.disable()
+    with torch.cuda.graph(*args, **kwargs):
+        yield
+    if gc_is_enabled:
+        gc.enable()
+
+
 def _make_graphed_callables(
     callables: SingleOrTuple[Callable],
     sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
@@ -445,7 +466,7 @@ def _make_graphed_callables(
                 args = sample_args[per_callable_fwd_idx]
                 kwargs = sample_kwargs[per_callable_fwd_idx]
                 fwd_graph = fwd_graphs[per_callable_fwd_idx]
-                with torch.cuda.graph(fwd_graph, pool=mempool):
+                with _graph_context_wrapper(fwd_graph, pool=mempool):
                     outputs = func(*args, **kwargs)
                     flatten_outputs, spec = _tree_flatten(outputs)
                 per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
@@ -483,7 +504,7 @@ def _make_graphed_callables(
                     torch.empty_like(o) if o.requires_grad else None for o in static_outputs
                 )
                 if is_training:
-                    with torch.cuda.graph(bwd_graph, pool=mempool):
+                    with _graph_context_wrapper(bwd_graph, pool=mempool):
                         grad_inputs = torch.autograd.grad(
                             outputs=tuple(o for o in static_outputs if o.requires_grad),
                             inputs=tuple(i for i in static_input_surface if i.requires_grad),
@@ -548,7 +569,7 @@ def _make_graphed_callables(
     per_callable_output_unflatten_spec = []
     graph_id = 0
     for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
-        with torch.cuda.graph(fwd_graph, pool=mempool):
+        with _graph_context_wrapper(fwd_graph, pool=mempool):
            outputs = func(*args, **kwargs)
        graph_callables[graph_id] = func
        graph_id += 1
@@ -570,7 +591,7 @@ def _make_graphed_callables(
             torch.empty_like(o) if o.requires_grad else None for o in static_outputs
         )
         if is_training:
-            with torch.cuda.graph(bwd_graph, pool=mempool):
+            with _graph_context_wrapper(bwd_graph, pool=mempool):
                 grad_inputs = torch.autograd.grad(
                     outputs=tuple(o for o in static_outputs if o.requires_grad),
                     inputs=tuple(i for i in static_input_surface if i.requires_grad),
......
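The _graph_context_wrapper above only changes how graph capture is entered; a hedged sketch of the public capture path that exercises it (the module choice, tensor shapes, and the assumption that a CUDA device is available are illustrative):

import torch
import transformer_engine.pytorch as te

linear = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# make_graphed_callables captures forward/backward CUDA graphs; with the
# workaround above, Python garbage collection is paused during each capture.
graphed_linear = te.make_graphed_callables(linear, (sample_input,))
out = graphed_linear(sample_input)
out.sum().backward()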
@@ -12,7 +12,6 @@ from typing import Any, Optional

 import torch

-from transformer_engine.pytorch.module.base import get_workspace
 from ...cpp_extensions import general_gemm
 from ...distributed import (
     CudaRNGStatesTracker,
@@ -20,18 +19,24 @@ from ...distributed import (
     reduce_scatter_along_first_dim,
 )
 from ...fp8 import FP8GlobalStateManager, Recipe
-from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
+from ...module.base import (
+    _2X_ACC_FPROP,
+    _2X_ACC_DGRAD,
+    _2X_ACC_WGRAD,
+    get_dummy_wgrad,
+    get_workspace,
+)
 from ...tensor import Quantizer
 from ...tensor.float8_tensor import Float8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
-from ..op import BasicOperation, OperationContext
-from .._common import maybe_dequantize, is_quantized_tensor
 from ...utils import (
     canonicalize_device,
     canonicalize_dtype,
     clear_tensor_data,
     devices_match,
 )
+from ..op import BasicOperation, OperationContext
+from .._common import maybe_dequantize, is_quantized_tensor


 def _wait_async(handle: Optional[Any]) -> None:
@@ -73,7 +78,8 @@ class BasicLinear(BasicOperation):
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
-        meaningful.
+        meaningful. This is primarily intented to integrate with
+        Megatron-LM.
     userbuffers_options, dict, optional
         Options for overlapping tensor-parallel communication with
         compute using Userbuffers. This feature is highly
@@ -979,20 +985,22 @@ class BasicLinear(BasicOperation):
         # Saved tensors from forward pass
         (x_local, w) = ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = self._accumulate_into_main_grad
         grad_weight = None
         if ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(self.weight, "__fsdp_param__"):
-                self.weight.main_grad = self.weight.get_main_grad()
+            weight_param = self.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()

-            if not hasattr(self.weight, "main_grad"):
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = self.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -1019,6 +1027,17 @@ class BasicLinear(BasicOperation):
         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
         if accumulate_into_main_grad:
             grad_weight = None
+            weight_param = self.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+            grad_weight = get_dummy_wgrad(
+                list(weight_param.size()),
+                weight_param.dtype,
+                zero=getattr(weight_param, "zero_out_wgrad", False),
+            )

         return grad_input, [grad_weight]
@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput
-from transformer_engine.pytorch.ops.op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
 from ...utils import clear_tensor_data
+from ..basic import BasicLinear, MakeExtraOutput
+from ..op import FusedOperation, FusibleOperation, OperationContext
 class BackwardLinearAdd(FusedOperation):
@@ -53,20 +50,22 @@ class BackwardLinearAdd(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()

-            if not hasattr(linear_op.weight, "main_grad"):
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -92,12 +91,23 @@ class BackwardLinearAdd(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+            grad_weight = get_dummy_wgrad(
+                list(weight_param.size()),
+                weight_param.dtype,
+                zero=getattr(weight_param, "zero_out_wgrad", False),
+            )
+
         return grad_input, [(grad_weight,), ()], [(), ()]
......
@@ -9,13 +9,10 @@ from typing import Optional

 import torch

-from ..basic import BasicLinear, ConstantScale
-from ..op import (
-    FusedOperation,
-    FusibleOperation,
-    OperationContext,
-)
+from ...module.base import get_dummy_wgrad
 from ...utils import clear_tensor_data
+from ..basic import BasicLinear, ConstantScale
+from ..op import FusedOperation, FusibleOperation, OperationContext
 class BackwardLinearScale(FusedOperation):
@@ -54,20 +51,22 @@ class BackwardLinearScale(FusedOperation):
         # Saved tensors from forward pass
         (x_local, w) = linear_op_ctx.saved_tensors

-        # wgrad fusion
+        # Megatron-LM wgrad fusion
+        # Note: Get grad tensor from param so we can accumulate
+        # directly into it.
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
-            if hasattr(linear_op.weight, "__fsdp_param__"):
-                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "__fsdp_param__"):
+                weight_param.main_grad = weight_param.get_main_grad()

-            if not hasattr(linear_op.weight, "main_grad"):
+            if not hasattr(weight_param, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
                     "accumulate_into_main_grad=True, "
                     "but weight parameter does not have main_grad attribute"
                 )
-            grad_weight = linear_op.weight.main_grad.detach()
+            grad_weight = weight_param.main_grad.detach()
         else:
             accumulate_into_main_grad = False
@@ -92,12 +91,23 @@ class BackwardLinearScale(FusedOperation):
             grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
             grad_input_quantizer=linear_op_ctx.grad_input_quantizer,
         )
-        if accumulate_into_main_grad:
-            grad_weight = None

         # Clear input tensor if possible
         clear_tensor_data(x_local)

+        # Megatron-LM wgrad fusion
+        # Note: Return dummy tensor for grad weight if needed.
+        if accumulate_into_main_grad:
+            grad_weight = None
+            weight_param = linear_op.weight
+            if hasattr(weight_param, "grad_added_to_main_grad"):
+                weight_param.grad_added_to_main_grad = True
+            grad_weight = get_dummy_wgrad(
+                list(weight_param.size()),
+                weight_param.dtype,
+                zero=getattr(weight_param, "zero_out_wgrad", False),
+            )
+
         return grad_input, [(), (grad_weight,)], [(), ()]
......
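A condensed sketch of the wgrad-fusion contract the three ops above now share, using a stand-in helper (finalize_wgrad and the zero-gradient fallback are illustrative; the attribute names and the get_dummy_wgrad call mirror the hunks above):

import torch
from transformer_engine.pytorch.module.base import get_dummy_wgrad

def finalize_wgrad(weight_param: torch.nn.Parameter, accumulate_into_main_grad: bool):
    """Gradient handed back to autograd for the weight."""
    if not accumulate_into_main_grad:
        # Normal path: the real weight gradient would be computed and returned here.
        return torch.zeros_like(weight_param)
    # Fused path: the wgrad GEMM has already accumulated into weight_param.main_grad,
    # so autograd only needs a correctly shaped placeholder tensor.
    if hasattr(weight_param, "grad_added_to_main_grad"):
        weight_param.grad_added_to_main_grad = True
    return get_dummy_wgrad(
        list(weight_param.size()),
        weight_param.dtype,
        zero=getattr(weight_param, "zero_out_wgrad", False),
    )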