Commit ab3e5a92 authored by yuguo's avatar yuguo
Browse files

Merge commit '04c730c0' of...

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
...@@ -37,8 +37,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea ...@@ -37,8 +37,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
constexpr NVTETensor workspace = nullptr; constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr; constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output, detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
dbias, workspace, stream); workspace, nullptr, stream);
} }
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
...@@ -46,6 +46,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no ...@@ -46,6 +46,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
NVTE_API_CALL(nvte_quantize_noop); NVTE_API_CALL(nvte_quantize_noop);
using namespace transformer_engine; using namespace transformer_engine;
// Create config with noop tensor
QuantizationConfig quant_config;
quant_config.noop_tensor = noop;
nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
}
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_v2);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false; constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false; constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
...@@ -53,8 +65,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no ...@@ -53,8 +65,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
constexpr NVTETensor workspace = nullptr; constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr; constexpr const NVTETensor grad = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output, detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
dbias, workspace, stream); input, grad, output, dbias, workspace, quant_config, stream);
} }
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
...@@ -68,7 +80,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d ...@@ -68,7 +80,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr const NVTETensor activation_input = nullptr; constexpr const NVTETensor activation_input = nullptr;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>( detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
activation_input, input, nullptr, output, dbias, workspace, stream); activation_input, input, output, dbias, workspace, nullptr, stream);
} }
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
...@@ -82,7 +94,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati ...@@ -82,7 +94,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>( detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream); activation_input, input, output, dbias, workspace, nullptr, stream);
} }
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
...@@ -96,7 +108,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati ...@@ -96,7 +108,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>( detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream); activation_input, input, output, dbias, workspace, nullptr, stream);
} }
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
...@@ -110,7 +122,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati ...@@ -110,7 +122,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>( detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream); activation_input, input, output, dbias, workspace, nullptr, stream);
} }
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
...@@ -124,7 +136,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat ...@@ -124,7 +136,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>( detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream); activation_input, input, output, dbias, workspace, nullptr, stream);
} }
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
...@@ -138,7 +150,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat ...@@ -138,7 +150,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
constexpr bool IS_ACT = false; constexpr bool IS_ACT = false;
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>( detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
activation_input, input, nullptr, output, dbias, workspace, stream); activation_input, input, output, dbias, workspace, nullptr, stream);
} }
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......
...@@ -99,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -99,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t in_mem = in_act_mem + in_gate_mem;
constexpr size_t out_act_mem = buff_size_aligned_out; constexpr size_t out_act_mem = buff_size_aligned_out;
constexpr size_t out_gate_mem = buff_size_aligned_out;
constexpr size_t out_mem = out_act_mem + out_gate_mem;
// const size_t in_transaction_size = grad_mem + in_mem; // const size_t in_transaction_size = grad_mem + in_mem;
constexpr size_t in_transaction_size = buff_elems * sizeof(IType); constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
...@@ -111,7 +109,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -111,7 +109,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem); IType *in_gate_sh = reinterpret_cast<IType *>(dshmem + grad_mem + in_act_mem);
OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem); OType *out_act_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem);
OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem); OType *out_gate_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_act_mem);
// uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem + grad_mem + in_mem + out_mem);
const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad); const uint64_t *TMAP_grad_in = reinterpret_cast<const uint64_t *>(&tensor_map_grad);
const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act); const uint64_t *TMAP_in_act = reinterpret_cast<const uint64_t *>(&tensor_map_input_act);
...@@ -294,7 +291,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -294,7 +291,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32
...@@ -839,8 +835,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out ...@@ -839,8 +835,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1;
size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
e8m0_t *const scales_rowwise_ptr = e8m0_t *const scales_rowwise_ptr =
USE_ROWWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr; USE_ROWWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr;
e8m0_t *const scales_colwise_ptr = e8m0_t *const scales_colwise_ptr =
......
...@@ -145,7 +145,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) ...@@ -145,7 +145,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1);
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -518,7 +517,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -518,7 +517,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
__shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1);
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -940,7 +938,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -940,7 +938,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
bool use_colwise_scaling = output->has_columnwise_data(); bool use_colwise_scaling = output->has_columnwise_data();
checkCuDriverContext(stream); checkCuDriverContext(stream);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
const auto &input_shape = input.data.shape;
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
if (use_rowwise_scaling) { if (use_rowwise_scaling) {
...@@ -1250,9 +1247,9 @@ namespace detail { ...@@ -1250,9 +1247,9 @@ namespace detail {
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)> float (*OP)(float, const ParamOP &)>
void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output,
NVTETensor output, NVTETensor dbias, NVTETensor workspace, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) { const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
const Tensor *input_tensor; const Tensor *input_tensor;
const Tensor *activation_input_tensor; const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) { if constexpr (IS_DBIAS || IS_DACT) {
...@@ -1267,6 +1264,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe ...@@ -1267,6 +1264,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
auto output_tensor = reinterpret_cast<Tensor *>(output); auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias); auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace); auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor(); const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
switch (output_tensor->scaling_mode) { switch (output_tensor->scaling_mode) {
...@@ -1294,6 +1297,36 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe ...@@ -1294,6 +1297,36 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
workspace_tensor, stream); workspace_tensor, stream);
break; break;
} }
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
? FP8BlockwiseRowwiseOption::ROWWISE
: FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option =
output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
: FP8BlockwiseColumnwiseOption::NONE;
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream);
break;
}
default: default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
} }
......
...@@ -59,7 +59,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -59,7 +59,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t scales_stride) { const size_t scales_stride) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1;
constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1;
constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128
constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32
...@@ -68,8 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -68,8 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128
constexpr size_t THREADS_PER_SCALE_X_ROWWISE = constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
...@@ -357,6 +355,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) ...@@ -357,6 +355,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
} }
} else { } else {
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
} }
} }
......
...@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( ...@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
: "memory"); : "memory");
} }
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global // shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
...@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( ...@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
: "memory"); : "memory");
} }
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() { __device__ __forceinline__ void cp_async_bulk_wait_group() {
asm volatile("cp.async.bulk.wait_group 0;"); asm volatile("cp.async.bulk.wait_group 0;");
...@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { ...@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
asm volatile("cp.async.bulk.wait_group.read 4;"); asm volatile("cp.async.bulk.wait_group.read 4;");
} }
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// Proxy fence (bi-directional): // Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } __device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); }
__device__ __forceinline__ void fence_proxy_async_shared_cta() { __device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;"); asm volatile("fence.proxy.async.shared::cta;");
} }
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx } // namespace ptx
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package for numerical debugging."""
# The debug tooling is optional: importing it can fail when optional
# dependencies (presumably the nvdlfw_inspect package and/or the PyTorch
# extras — TODO confirm) are not installed. In that case the package is
# exposed empty rather than breaking `import transformer_engine.debug`.
try:
    from . import pytorch
    from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError:
    # Deliberate best-effort import; nothing to do when the extras are absent.
    pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains DebugQuantizer and DebugQuantizedTensor objects,
which are wrappers over Quantizer and QuantizedTensor.
These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
# Shorthand for the ATen operator namespace.
aten = torch.ops.aten

# Maps a tensor name to the [rowwise_gemm, columnwise_gemm] it participates in.
# A None entry means the tensor feeds only one GEMM (no columnwise usage).
_tensor_to_gemm_names_map = {
    "weight": ["fprop", "dgrad"],
    "activation": ["fprop", "wgrad"],
    "output": ["fprop", None],
    "gradient": ["dgrad", "wgrad"],
    "wgrad": ["wgrad", None],
    "dgrad": ["dgrad", None],
}

# Plan labels used by DebugQuantizer's *_tensor_plan fields to describe how a
# tensor will be produced: via the modify_tensor() debug API call, via the
# parent quantizer's standard FP8 quantization, or kept in high precision.
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
HIGH_PRECISION = "High Precision"
class DebugQuantizer(Quantizer):
"""
DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect.
It allows adding custom calls inside the quantization process - which enables modifying tensors
or gathering tensor stats.
"""
def __init__(
    self,
    layer_name: str,
    tensor_name: str,
    parent_quantizer: Optional[Quantizer],
    tp_group: torch.distributed.ProcessGroup,
):
    """
    Build a debug quantizer for one named tensor of one layer.

    Args:
        layer_name: name of the layer, used to query the nvdlfw_inspect API.
        tensor_name: one of the keys of _tensor_to_gemm_names_map
            ("weight", "activation", "output", "gradient", "wgrad", "dgrad").
        parent_quantizer: the underlying Quantizer used when the plan is
            STANDARD_FP8_QUANTIZE; may be None (then FP8 is never chosen).
        tp_group: tensor-parallel process group forwarded to inspect_tensor
            calls.
    """
    # Imported lazily so the debug package stays importable without
    # nvdlfw_inspect installed.
    import nvdlfw_inspect.api as debug_api

    super().__init__(rowwise=True, columnwise=True)
    self.layer_name = layer_name
    self.tensor_name = tensor_name
    self.parent_quantizer = parent_quantizer
    self.tp_group = tp_group  # used in inspect_tensor calls
    # Snapshot of the trainer iteration at construction time; all feature
    # queries below are made for this iteration.
    self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
    self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

    # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
    # rowwise_tensor_plan, and columnwise_tensor_plan are computed.
    # These fields indicate the path where API calls will be inserted.
    #
    # inspect_tensor*_enabled are bool fields,
    # indicating whether some feature will need to run inspect_tensor_* calls.
    #
    # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
    # determining what will happen when the quantizer is used for that tensor.
    self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
    if self.output_tensor:
        # Output-side tensors only ever get a rowwise plan (no FP8 output).
        self.inspect_tensor_enabled, self.rowwise_tensor_plan = (
            self.get_plans_for_output_tensors()
        )
    else:
        (
            self.inspect_tensor_enabled,
            self.inspect_tensor_postquantize_enabled_rowwise,
            self.inspect_tensor_postquantize_enabled_columnwise,
        ) = self.get_enabled_look_at_tensors()
        self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan()

    self.log_messages_about_plans()
def get_plans_for_output_tensors(self) -> Tuple[bool, str]:
    """
    Compute the debug plan for an output-side tensor.

    Returns a tuple ``(inspect_tensor_enabled, plan)``. The plan is either
    API_CALL_MODIFY or HIGH_PRECISION, because the debug quantizer does not
    support GEMM output in FP8.
    """
    import nvdlfw_inspect.api as debug_api

    should_inspect = debug_api.transformer_engine.inspect_tensor_enabled(
        layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
    )
    wants_modify = debug_api.transformer_engine.modify_tensor_enabled(
        layer_name=self.layer_name,
        gemm=self.rowwise_gemm_name,
        tensor_name=self.tensor_name,
        iteration=self.iteration,
    )
    if wants_modify:
        return should_inspect, API_CALL_MODIFY
    return should_inspect, HIGH_PRECISION
def get_enabled_look_at_tensors(self):
    """
    Query which inspection hooks are active for this tensor.

    Returns a 3-tuple of booleans:
    ``(inspect_tensor_enabled,
       inspect_tensor_postquantize_enabled_rowwise,
       inspect_tensor_postquantize_enabled_columnwise)``.
    """
    import nvdlfw_inspect.api as debug_api

    te_api = debug_api.transformer_engine
    # Arguments shared by all three feature queries.
    common = {
        "layer_name": self.layer_name,
        "tensor_name": self.tensor_name,
        "iteration": self.iteration,
    }
    pre_quantize_enabled = te_api.inspect_tensor_enabled(**common)
    post_rowwise_enabled = te_api.inspect_tensor_postquantize_enabled(
        gemm=self.rowwise_gemm_name, **common
    )
    post_columnwise_enabled = te_api.inspect_tensor_postquantize_enabled(
        gemm=self.columnwise_gemm_name, **common
    )
    return (
        pre_quantize_enabled,
        post_rowwise_enabled,
        post_columnwise_enabled,
    )
def get_tensors_plan(self):
    """
    Compute the debug plan for the rowwise and columnwise GEMM inputs.

    Returns ``(rowwise_plan, columnwise_plan)``. Each plan is one of
    API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION.
    ``columnwise_plan`` is None when this tensor has no columnwise GEMM
    (``self.columnwise_gemm_name is None``).
    """
    import nvdlfw_inspect.api as debug_api

    def _plan_for_gemm(gemm_name):
        # Priority order (same for both gemms):
        #   1. an explicit modify_tensor() hook,
        #   2. standard FP8 quantization (only possible with a parent quantizer
        #      and when the feature enables FP8 for this gemm),
        #   3. fall back to high precision.
        if debug_api.transformer_engine.modify_tensor_enabled(
            layer_name=self.layer_name,
            gemm=gemm_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
        ):
            return API_CALL_MODIFY
        if self.parent_quantizer is not None:
            if debug_api.transformer_engine.fp8_gemm_enabled(
                layer_name=self.layer_name,
                gemm=gemm_name,
                iteration=self.iteration,
            ):
                return STANDARD_FP8_QUANTIZE
        return HIGH_PRECISION

    rowwise_plan = _plan_for_gemm(self.rowwise_gemm_name)
    columnwise_plan = (
        _plan_for_gemm(self.columnwise_gemm_name)
        if self.columnwise_gemm_name is not None
        else None
    )
    return rowwise_plan, columnwise_plan
def log_messages_about_plans(self):
    """Log one message per (gemm, plan) pair chosen for this tensor."""
    import nvdlfw_inspect.api as debug_api

    for gemm_name, plan in (
        (self.rowwise_gemm_name, self.rowwise_tensor_plan),
        (self.columnwise_gemm_name, self.columnwise_tensor_plan),
    ):
        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {gemm_name} - {plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(gemm_name, self.tensor_name),
        )
def _call_inspect_tensor_api(
    self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None
):
    """
    Run the enabled inspect_tensor / inspect_tensor_postquantize hooks.

    ``tensor`` is the pre-quantization tensor; ``rowwise_gemm_tensor`` and
    ``columnwise_gemm_tensor`` are the post-quantization tensors fed to the
    respective GEMMs. Output-side tensors get only the pre-quantization hook.
    """
    import nvdlfw_inspect.api as debug_api

    call_kwargs = {
        "layer_name": self.layer_name,
        "tensor": tensor,
        "tensor_name": self.tensor_name,
        "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
        "tp_group": self.tp_group,
    }
    if tensor is not None and self.inspect_tensor_enabled:
        debug_api.transformer_engine.inspect_tensor(**call_kwargs)

    # Output tensors are never a GEMM input, so there is nothing to inspect
    # after quantization.
    if self.output_tensor:
        return

    quantized_plans = [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
    if (
        self.rowwise_tensor_plan in quantized_plans
        and self.inspect_tensor_postquantize_enabled_rowwise
    ):
        call_kwargs["tensor"] = rowwise_gemm_tensor
        call_kwargs["rowwise"] = True
        debug_api.transformer_engine.inspect_tensor_postquantize(**call_kwargs)
    if (
        self.columnwise_tensor_plan in quantized_plans
        and self.inspect_tensor_postquantize_enabled_columnwise
    ):
        call_kwargs["tensor"] = columnwise_gemm_tensor
        call_kwargs["rowwise"] = False
        debug_api.transformer_engine.inspect_tensor_postquantize(**call_kwargs)
def quantize(
    self,
    tensor: torch.Tensor,
    *,
    out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None,
    dtype: torch.dtype = None,
):
    """Quantize ``tensor`` according to the rowwise/columnwise tensor plans.

    Args:
        tensor: high-precision input tensor.
        out: optional pre-allocated destination; when given, it is updated
            in place via ``update_quantized()`` and its result returned.
        dtype: dtype used for high-precision fallbacks and for validating
            the output of ``modify_tensor()`` calls.

    Returns:
        A ``DebugQuantizedTensor`` wrapping the per-gemm tensors, or a plain
        tensor for "wgrad"/"dgrad"/"output" tensor names.

    Raises:
        ValueError: if a ``modify_tensor()`` result has a mismatched dtype.
    """
    import nvdlfw_inspect.api as debug_api

    assert not self.output_tensor
    if out is not None:
        # Bug fix: update the caller-provided destination `out`, not `self`.
        # Previously `self` (the quantizer) was passed as the destination,
        # which has no rowwise/columnwise gemm tensors to update.
        return self.update_quantized(tensor, out)

    # 1. If there is fp8 quantization in at least one of the gemms,
    # the quantization using the self.parent_quantizer is performed.
    # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise.
    rowwise_gemm_quantize = (
        self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
    )
    columnwise_gemm_quantize = (
        self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
    )
    if columnwise_gemm_quantize and not rowwise_gemm_quantize:
        rowwise_gemm_quantize = True  # only columnwise quantization not implemented

    rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
    if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
        self.parent_quantizer.set_usage(
            rowwise=True,
            columnwise=columnwise_gemm_quantize,  # columnwise usage only is not supported
        )
        quantized_tensor = self.parent_quantizer(tensor)
        # If both plans need fp8, one tensor with rowwise=True and
        # columnwise=True is computed and both plans point to it.
        if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
            rowwise_gemm_tensor = quantized_tensor
        if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
            columnwise_gemm_tensor = quantized_tensor

    # 2. modify_tensor() is called, if it is used.
    if self.columnwise_tensor_plan == API_CALL_MODIFY:
        columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.columnwise_gemm_name,
            tensor=tensor,
            default_quantizer=self.parent_quantizer,
            iteration=self.iteration,
            dtype=dtype,
        )
        if columnwise_gemm_tensor.dtype != dtype:
            raise ValueError("Dtype does not match the output of the modify_tensor call")
    if self.rowwise_tensor_plan == API_CALL_MODIFY:
        rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.rowwise_gemm_name,
            tensor=tensor,
            default_quantizer=self.parent_quantizer,
            iteration=self.iteration,
            dtype=dtype,
        )
        if rowwise_gemm_tensor.dtype != dtype:
            raise ValueError("Dtype does not match the output of the modify_tensor call")

    # 3. If some tensors still are not defined we use the high precision tensor.
    if self.rowwise_tensor_plan == HIGH_PRECISION:
        rowwise_gemm_tensor = tensor.to(dtype)
    if self.columnwise_tensor_plan == HIGH_PRECISION:
        columnwise_gemm_tensor = tensor.to(dtype)

    self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor)

    # Sometimes we may want to return a simple tensor with only rowwise_gemm.
    if self.tensor_name in ["wgrad", "dgrad", "output"]:
        return rowwise_gemm_tensor
    return DebugQuantizedTensor(
        rowwise_gemm_tensor=rowwise_gemm_tensor,
        columnwise_gemm_tensor=columnwise_gemm_tensor,
        quantizer=self,
        layer_name=self.layer_name,
        tensor_name=self.tensor_name,
    )
def process_gemm_output(self, tensor: torch.Tensor):
    """Invoked after the gemm to inspect and optionally modify its output."""
    import nvdlfw_inspect.api as debug_api

    assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
    assert self.output_tensor

    gemm_for_tensor = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
    if self.rowwise_tensor_plan == API_CALL_MODIFY:
        tensor = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            gemm=gemm_for_tensor[self.tensor_name],
            tensor_name=self.tensor_name,
            tensor=tensor,
            iteration=self.iteration,
            default_quantizer=self.parent_quantizer,
        )
    self._call_inspect_tensor_api(tensor)
    return tensor
def make_empty(
    self,
    shape: Iterable[int],
    *,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> QuantizedTensor:
    """Allocate an uninitialized tensor of ``shape``.

    Delegates to the parent quantizer when one is configured; otherwise
    returns a plain uninitialized torch tensor.
    """
    if self.parent_quantizer is None:
        return torch.empty(shape, dtype=dtype, device=device)
    return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
def calibrate(self, tensor: torch.Tensor):
    """Calibration override.

    Raises:
        RuntimeError: always — calibration is incompatible with debug mode.
    """
    message = "[NVTORCH-INSPECT ERROR] Calibration with debug is not supported"
    raise RuntimeError(message)
def update_quantized(
    self,
    src: torch.Tensor,
    dst: QuantizedTensor,
    *,
    noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
    """Re-quantize ``src`` into the pre-allocated ``dst`` (weight-caching path).

    Args:
        src: high-precision source tensor.
        dst: destination holding rowwise/columnwise gemm tensors to update.
        noop_flag: must be None — CUDA Graphs are unsupported in debug mode.

    Returns:
        The updated ``dst``.
    """
    import nvdlfw_inspect.api as debug_api

    assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
    updated_rowwise_gemm = False
    if self.parent_quantizer is not None:
        if (
            dst.rowwise_gemm_tensor is not None
            and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
        ):
            if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
            else:
                tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None)
            updated_rowwise_gemm = True
        # When both plans are fp8 the two gemm tensors are the same object,
        # so a rowwise update already covers the columnwise one.
        if (
            dst.columnwise_gemm_tensor is not None
            and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
            and not updated_rowwise_gemm
        ):
            if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
                dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None)
            else:
                tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None)

    if self.columnwise_tensor_plan == API_CALL_MODIFY:
        out = debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.columnwise_gemm_name,
            tensor=src,
            default_quantizer=self.parent_quantizer,
            out=dst.columnwise_gemm_tensor,
            iteration=self.iteration,
        )
        assert out is None, (
            "API call debug_api.transformer_engine.modify_tensor with out != None should"
            " return None"
        )
    if self.rowwise_tensor_plan == API_CALL_MODIFY:
        debug_api.transformer_engine.modify_tensor(
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
            gemm=self.rowwise_gemm_name,
            tensor=src,
            default_quantizer=self.parent_quantizer,
            out=dst.rowwise_gemm_tensor,
            iteration=self.iteration,
        )

    if self.rowwise_tensor_plan == HIGH_PRECISION:
        dst.rowwise_gemm_tensor.copy_(src)
    if self.columnwise_tensor_plan == HIGH_PRECISION:
        # If they are the same tensor object, it is sufficient to update one.
        if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor:
            dst.columnwise_gemm_tensor.copy_(src)

    self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
    # Fix: the signature declares `-> QuantizedTensor` and quantize() forwards
    # this call's result to its caller, so return the updated destination
    # instead of implicitly returning None.
    return dst
def any_feature_enabled(self) -> bool:
    """Return True if at least one debug API call is enabled for this tensor."""
    if self.output_tensor:
        # Output tensors only support inspection and rowwise modify_tensor.
        return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY

    inspection_active = (
        self.inspect_tensor_enabled
        or self.inspect_tensor_postquantize_enabled_rowwise
        or self.inspect_tensor_postquantize_enabled_columnwise
    )
    modification_active = API_CALL_MODIFY in (
        self.rowwise_tensor_plan,
        self.columnwise_tensor_plan,
    )
    if inspection_active or modification_active:
        return True

    # With an fp8 parent quantizer, any plan other than the plain fp8
    # quantize means debug has altered the default behaviour.
    if self.parent_quantizer is not None:
        return (
            self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE
            or self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE
        )
    return False
class DebugQuantizedTensor:
    """Container for the tensors produced by a debug quantizer.

    Depending on the configuration it holds one shared object or two distinct
    objects for the rowwise/columnwise gemms; access them via get_tensor().
    """

    def __init__(
        self,
        rowwise_gemm_tensor,
        columnwise_gemm_tensor,
        quantizer,
        layer_name=None,
        tensor_name=None,
    ):
        # The two gemm inputs may reference the same underlying object.
        self.rowwise_gemm_tensor = rowwise_gemm_tensor
        self.columnwise_gemm_tensor = columnwise_gemm_tensor
        self.quantizer = quantizer
        self._layer_name = layer_name
        self._tensor_name = tensor_name

    def prepare_for_saving(self):
        """Prepare-for-saving override; saves shared tensors only once."""
        if self.rowwise_gemm_tensor is self.columnwise_gemm_tensor:
            self.tensors_to_save = [self.rowwise_gemm_tensor]
        else:
            self.tensors_to_save = [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor]
        tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
        self.tensors_to_save = tensor_objects_list
        # pylint: disable=unbalanced-tuple-unpacking
        return tensor_list, self

    def restore_from_saved(self, tensors):
        """Restore-from-saved override; re-links a shared tensor if needed."""
        tensor_objects_list, saved_tensors = restore_from_saved(
            self.tensors_to_save,
            tensors,
            return_saved_tensors=True,
        )
        if len(tensor_objects_list) == 2:
            # pylint: disable=unbalanced-tuple-unpacking
            self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list
        else:
            # A single saved tensor means both gemms share it.
            self.rowwise_gemm_tensor = tensor_objects_list[0]
            self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
        return saved_tensors

    def quantize_(self, tensor, *, noop_flag=None):
        """quantize_ override; delegates to the owning quantizer."""
        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
        self.quantizer.update_quantized(tensor, self)

    def dequantize(self, *, dtype=None):
        """dequantize override; defaults to the rowwise tensor's dtype."""
        target = self.rowwise_gemm_tensor.dtype if dtype is None else dtype
        return self.rowwise_gemm_tensor.dequantize().to(target)

    def get_tensor(self, transpose: bool):
        """Used in the python gemm() to get the tensor or its transpose."""
        if transpose:
            return self.columnwise_gemm_tensor
        return self.rowwise_gemm_tensor

    def size(self):
        """Size of the (rowwise) tensor."""
        return self.rowwise_gemm_tensor.size()

    def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
        """Update usage of the tensor."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Managing the state of all the debugged layers.
"""
import sys
class TEDebugState:
    """Global bookkeeping for all debug-enabled TE layers."""

    layer_count = 1  # next auto-assigned layer number
    layers_initialized = {}  # registry of initialized layers
    weight_tensor_tp_group_reduce = True  # reduce weight stats across the TP group
    debug_enabled = None  # None until initialize() has made a determination

    @classmethod
    def initialize(cls):
        """Set cls.debug_enabled to True if the debug_api module is initialized."""
        if "nvdlfw_inspect" not in sys.modules:
            return
        import nvdlfw_inspect.api as debug_api

        if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None:
            # This method runs when a TE module is constructed. Reaching this
            # branch means a TE module was created before debug_api was
            # initialized and another one afterwards — likely a bug.
            raise RuntimeError(
                "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before"
                " initialization of the first TE module"
            )
        cls.debug_enabled = debug_api.DEBUG_MANAGER is not None

    @classmethod
    def _reset(cls):
        """Clear the stats buffers, the debug flag and the layer registry."""
        from ..features.utils.stats_buffer import STATS_BUFFERS

        STATS_BUFFERS.reset()
        cls.debug_enabled = None
        cls.layers_initialized.clear()

    @classmethod
    def get_layer_count(cls):
        """Return the next layer number; used when layer names are not
        provided to modules by the user."""
        current = cls.layer_count
        cls.layer_count += 1
        return current

    @classmethod
    def set_weight_tensor_tp_group_reduce(cls, enabled):
        """Sets weight tensor reduction mode."""
        cls.weight_tensor_tp_group_reduce = enabled
def set_weight_tensor_tp_group_reduce(enabled):
    """Module-level convenience wrapper for TEDebugState's reduction setter."""
    TEDebugState.set_weight_tensor_tp_group_reduce(enabled)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils functions for the debug module."""
def any_feature_enabled(quantizers):
    """Return True when any quantizer in ``quantizers`` reports an active
    debug feature (i.e. at least one API call will be made)."""
    for quantizer in quantizers:
        if quantizer.any_feature_enabled():
            return True
    return False
...@@ -83,7 +83,8 @@ _load_library() ...@@ -83,7 +83,8 @@ _load_library()
from . import flax from . import flax
from . import quantize from . import quantize
from .quantize import fp8_autocast from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType from .sharding import MajorShardingType, ShardingResource, ShardingType
...@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper( ...@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper(
) )
__all__ = [ __all__ = [
"NVTE_FP8_COLLECTION_NAME",
"fp8_autocast", "fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource", "MeshResource",
"MajorShardingType", "MajorShardingType",
"ShardingResource", "ShardingResource",
"ShardingType", "ShardingType",
"flax", "flax",
"praxis", "quantize",
] ]
...@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g): ...@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx (x, _) = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type) dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx, None) return (dx, None)
......
...@@ -10,6 +10,7 @@ from packaging import version ...@@ -10,6 +10,7 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
...@@ -26,12 +27,12 @@ from .misc import ( ...@@ -26,12 +27,12 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100, should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding, NamedSharding,
) )
from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeAxis, QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
) )
...@@ -110,41 +111,31 @@ class ActLuPrimitive(BasePrimitive): ...@@ -110,41 +111,31 @@ class ActLuPrimitive(BasePrimitive):
""" """
te_act_lu_p abstract te_act_lu_p abstract
""" """
del act_enum, act_len, scale_shapes del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
assert x_aval.shape[-2] == act_len, (
out_shape = ( "activation input should be replicated by act_len in the -2 axis, got input shape"
*x_aval.shape[:-2], f" {x_aval.shape} and act_len {act_len}"
1,
x_aval.shape[-1],
) )
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer) ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1)
if not is_2x:
if len(rowwise_scale_inv_shape) > 1: out_shape = (1,)
rowwise_scale_inv_shape = ( colwise_scale_inv_shape = (1,)
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:] colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
)
if len(colwise_scale_inv_shape) > 1:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) shape=colwise_scale_inv_shape, dtype=scale_dtype
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) )
if is_2x:
colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval
...@@ -172,7 +163,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -172,7 +163,7 @@ class ActLuPrimitive(BasePrimitive):
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
out = ffi.ffi_lowering(ActLuPrimitive.name)( out = ffi.ffi_lowering(ActLuPrimitive.name)(
ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x
) )
return out return out
...@@ -211,15 +202,8 @@ class ActLuPrimitive(BasePrimitive): ...@@ -211,15 +202,8 @@ class ActLuPrimitive(BasePrimitive):
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False) ).get_scale_shape_2x(out.shape, is_padded=False, flatten_axis=-1)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: # Slice out padding for MXFP8, noop for DelayedScaling
rowwise_scale_inv_shape = (
rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
)
if is_2x:
colwise_scale_inv_shape = (
colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
)
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
) )
...@@ -227,6 +211,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -227,6 +211,7 @@ class ActLuPrimitive(BasePrimitive):
colwise_scale_inv = jax.lax.slice( colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax
@staticmethod @staticmethod
...@@ -292,11 +277,14 @@ class ActLuPrimitive(BasePrimitive): ...@@ -292,11 +277,14 @@ class ActLuPrimitive(BasePrimitive):
is_outer, is_outer,
) # Unused. ) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-2], None, x_spec[-2]) scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-2], x_spec[-1])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
colwise_out_spec = out_spec colwise_out_spec = out_spec
else: else:
...@@ -304,18 +292,24 @@ class ActLuPrimitive(BasePrimitive): ...@@ -304,18 +292,24 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
) )
return ( return (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -340,14 +334,14 @@ class ActLuPrimitive(BasePrimitive): ...@@ -340,14 +334,14 @@ class ActLuPrimitive(BasePrimitive):
): ):
del result_infos, is_outer # Unused. del result_infos, is_outer # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
out_spec = (*x_spec[:-1], x_spec[-1]) scale_spec = get_padded_spec(arg_infos[1])
if act_len == 2 and x_spec[-1] is None:
# Ensure last axis is partitioned and not the gating axis out_spec = (*x_spec[:-2], x_spec[-1])
out_spec = (*x_spec[:-2], None, x_spec[-2])
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(out_spec) colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1)
else: else:
colwise_out_spec = out_spec colwise_out_spec = out_spec
else: else:
...@@ -355,21 +349,25 @@ class ActLuPrimitive(BasePrimitive): ...@@ -355,21 +349,25 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding = NamedSharding( colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out" mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax") amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
colwise_scale_inv_sharding = NamedSharding(
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"ActLuPrimitive.colwise_scale_inv"
) )
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec)) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = tuple(arg_shardings)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -394,7 +392,7 @@ class ActLuPrimitive(BasePrimitive): ...@@ -394,7 +392,7 @@ class ActLuPrimitive(BasePrimitive):
) )
) )
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -409,10 +407,59 @@ class ActLuPrimitive(BasePrimitive): ...@@ -409,10 +407,59 @@ class ActLuPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
out_dtype,
act_enum,
act_len,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types
x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var="i", flatten_axis=-2
)
x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
out = (*x_axes[:-2], x_axes[-1])
scale_inv = scale_rules.rowwise_rule
colwise_scale_inv = scale_rules.colwise_rule
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(
multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
)
else:
colwise_out = out
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
# amax is always a unit tensor.
amax = ("l",)
return SdyShardingRule(
(
x_axes,
"…1",
),
(out, colwise_out, scale_inv, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
)
register_primitive(ActLuPrimitive) register_primitive(ActLuPrimitive)
# TODO(Jeremy): replace is_2x with q_layout
class DActLuDBiasQuantizePrimitive(BasePrimitive): class DActLuDBiasQuantizePrimitive(BasePrimitive):
""" """
DActLu DBias Cast Transpose Primitive DActLu DBias Cast Transpose Primitive
...@@ -445,42 +492,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -445,42 +492,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p abstract te_dact_dbias_quantize_p abstract
""" """
del act_enum, scale_shapes del act_enum, scale_shapes
dtype = dtypes.canonicalize_dtype(dz_aval.dtype) dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dtype assert x_aval.dtype == dz_dtype
assert x_aval.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32 assert scale_aval.dtype == jnp.float32
ir_hidden_size = dz_aval.shape[-1] ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = x_aval.shape[-1] gi_hidden_size = act_len * x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size assert act_len * ir_hidden_size == gi_hidden_size
out_shape = x_aval.shape out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if is_2x: if is_2x:
# Don't transpose output for MXFP8 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
t_shape = out_shape
else: else:
t_shape = multidim_transpose(out_shape) colwise_out_shape = out_shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype) else:
colwise_scale_inv_aval = jax.core.ShapedArray( colwise_out_shape = (1,)
shape=colwise_scale_inv_shape, dtype=scale_dtype colwise_scale_inv_shape = (1,)
) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias: if is_dbias:
dbias_shape = gi_hidden_size dbias_shape = (act_len, ir_hidden_size)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size, x_aval.size // gi_hidden_size,
gi_hidden_size, gi_hidden_size,
...@@ -489,9 +535,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -489,9 +535,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode, scaling_mode,
is_2x, is_2x,
) )
wkspace_aval = x_aval.update( wkspace_shape = wkspace_info[0]
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
) else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dz_dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return ( return (
out_aval, out_aval,
...@@ -543,7 +594,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -543,7 +594,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
dz, dz,
x, x,
scale, scale,
scaling_mode=scaling_mode, scaling_mode=scaling_mode.value,
is_2x=is_2x, is_2x=is_2x,
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=int(act_enum), act_enum=int(act_enum),
...@@ -587,23 +638,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -587,23 +638,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False) ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=-2)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: # Slice out padding for MXFP8, noop for DelayedScaling
scale_inv = jax.lax.slice( scale_inv = jax.lax.slice(
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
)
if is_2x:
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
) )
if is_2x: return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias
colwise_scale_inv = jax.lax.slice(
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return (
out,
colwise_out,
scale_inv,
colwise_scale_inv,
updated_amax,
dbias,
) # Exclude wkspace
@staticmethod @staticmethod
def batcher( def batcher(
...@@ -670,15 +714,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -670,15 +714,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos, result_infos,
): ):
del out_dtype, result_infos, act_enum del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, is_dbias, act_len, is_outer del scale_dtype, scale_shapes, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out" mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
) )
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
colwise_x_spec = x_spec colwise_x_spec = x_spec
else: else:
...@@ -687,23 +732,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -687,23 +732,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
) )
dbias_shaprding = NamedSharding( dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias", desc="DActLuDBiasQuantizePrimitive.dbias",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
) )
amax_sharding = NamedSharding( amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax" mesh, PartitionSpec(*amax_spec), desc="DActLuDBiasQuantizePrimitive.amax"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: colwise_scale_inv_sharding = NamedSharding(
scale_inv_sharding = NamedSharding( mesh,
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" PartitionSpec(*colwise_scale_inv_spec),
) desc="DActLuDBiasQuantizePrimitive.colwise_scale_inv",
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
) )
return ( return (
out_sharding, out_sharding,
...@@ -711,7 +765,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -711,7 +765,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding, scale_inv_sharding,
colwise_scale_inv_sharding, colwise_scale_inv_sharding,
amax_sharding, amax_sharding,
dbias_shaprding, dbias_sharding,
) )
@staticmethod @staticmethod
...@@ -731,10 +785,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -731,10 +785,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[1]) x_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out") scale_spec = get_padded_spec(arg_infos[2])
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
if is_2x: if is_2x:
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_x_spec = multidim_transpose(x_spec) colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2)
else: else:
colwise_x_spec = x_spec colwise_x_spec = x_spec
else: else:
...@@ -743,38 +802,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -743,38 +802,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out" mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
) )
dbias_shaprding = NamedSharding( dbias_spec = x_spec[-2:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*dbias_spec),
desc="DActLuDBiasQuantizePrimitive.dbias", desc="DActLuDBiasQuantizePrimitive.dbias",
) )
scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if is_2x:
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="ActLuPrimitive.scale_inv"
)
amax_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="ActLuPrimitive.amax")
scale_inv_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv" mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
) )
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = (
arg_shardings[1],
arg_shardings[1],
*arg_shardings[2:],
) # dz and x are the same
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
scale_inv_sharding, scale_inv_sharding,
colwise_scale_inv_sharding, colwise_scale_inv_sharding,
amax_sharding, amax_sharding,
dbias_shaprding, dbias_sharding,
) )
def sharded_impl(dz, x, scale): def sharded_impl(dz, x, scale):
...@@ -799,7 +859,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -799,7 +859,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else: else:
global_dbias = local_dbias global_dbias = local_dbias
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -808,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -808,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="i", flatten_axis=-2
)
x_axes = scale_rules.input_spec
out = x_axes
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
else:
colwise_out = tuple(x_axes)
else:
colwise_out = ("j",)
dbias = x_axes[-2:] if is_dbias else ("k",)
amax = ("…4",)
return SdyShardingRule(
(("…0",), tuple(x_axes), ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
**scale_rules.factor_sizes,
)
register_primitive(DActLuDBiasQuantizePrimitive) register_primitive(DActLuDBiasQuantizePrimitive)
...@@ -816,14 +916,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S ...@@ -816,14 +916,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
""" """
JAX native activation implementation JAX native activation implementation
""" """
x = jnp.split(inputs, len(activation_type), axis=-1) act_len = len(activation_type)
assert inputs.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {inputs.shape} and act_len {act_len}"
)
x = jnp.split(inputs, act_len, axis=-2)
acts = [] acts = []
for idx, act_fn in enumerate(activation_type): for idx, act_fn in enumerate(activation_type):
x_i = _convert_to_activation_function(act_fn)(x[idx]) x_i = _convert_to_activation_function(act_fn)(x[idx])
acts.append(x_i) acts.append(x_i)
x = reduce(operator.mul, acts) x = reduce(operator.mul, acts)
x = jnp.squeeze(x, axis=-2)
if quantizer: if quantizer:
return quantizer.quantize(x) return quantizer.quantize(x, flatten_axis=-1)
return x return x
...@@ -837,6 +944,12 @@ def _jax_quantize_dact_dbias( ...@@ -837,6 +944,12 @@ def _jax_quantize_dact_dbias(
""" """
JAX implementation of dact_lu and dbias with optional quantization JAX implementation of dact_lu and dbias with optional quantization
""" """
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
_, vjp_func = jax.vjp( _, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
) )
...@@ -844,10 +957,10 @@ def _jax_quantize_dact_dbias( ...@@ -844,10 +957,10 @@ def _jax_quantize_dact_dbias(
dbias = None dbias = None
if is_dbias: if is_dbias:
dbias = _jax_dbias(dx).astype(x.dtype) dbias = _jax_dbias(dx, dtype=x.dtype, flatten_axis=-2)
if quantizer is not None: if quantizer is not None:
dx = quantizer.quantize(dx, dq_dtype=x.dtype) dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
else: else:
dx = dx.astype(x.dtype) dx = dx.astype(x.dtype)
...@@ -863,6 +976,7 @@ def act_lu( ...@@ -863,6 +976,7 @@ def act_lu(
Args: Args:
x: Input tensor to be processed. x: Input tensor to be processed.
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply. activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
...@@ -873,12 +987,17 @@ def act_lu( ...@@ -873,12 +987,17 @@ def act_lu(
A ScaledTensor containing the quantized activated input. A ScaledTensor containing the quantized activated input.
""" """
act_type_id = ActivationEnum[activation_type].value act_type_id = ActivationEnum[activation_type].value
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not ActLuPrimitive.enabled(): if not ActLuPrimitive.enabled():
return _jax_act_lu(x, activation_type, quantizer) return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_act_lu(x, activation_type, quantizer) return _jax_act_lu(x, activation_type, quantizer)
# TE/common does not support 2x quantization for DelayedScaling yet # TE/common does not support 2x quantization for DelayedScaling yet
...@@ -889,17 +1008,16 @@ def act_lu( ...@@ -889,17 +1008,16 @@ def act_lu(
return war_output return war_output
scale = jnp.empty((1,), jnp.float32) scale = jnp.empty((1,), jnp.float32)
output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type)) output_shape = (*x.shape[:-2], x.shape[-1])
if quantizer is None: if quantizer is None:
x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
x, x,
scale, scale,
out_dtype=x.dtype, out_dtype=x.dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((), ()), scale_shapes=((), ()),
...@@ -911,7 +1029,6 @@ def act_lu( ...@@ -911,7 +1029,6 @@ def act_lu(
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale scale = quantizer.scale
x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
( (
rowwise_casted_output, rowwise_casted_output,
colwise_casted_output, colwise_casted_output,
...@@ -923,25 +1040,15 @@ def act_lu( ...@@ -923,25 +1040,15 @@ def act_lu(
scale, scale,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(output_shape), # output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True, is_outer=True,
) )
rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
if len(rowwise_scale_inv.shape) > 1:
rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2) # Remove act axis
if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
colwise_output_shape = output_shape
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
colwise_output_shape = multidim_transpose(output_shape)
colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
if len(colwise_scale_inv.shape) > 1:
colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2) # Remove act axis
quantizer.update(updated_amax) quantizer.update(updated_amax)
return ScaledTensorFactory.create( return ScaledTensorFactory.create(
...@@ -951,8 +1058,8 @@ def act_lu( ...@@ -951,8 +1058,8 @@ def act_lu(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
...@@ -968,7 +1075,7 @@ def quantize_dact_dbias( ...@@ -968,7 +1075,7 @@ def quantize_dact_dbias(
Args: Args:
dz: Gradient of the output with respect to the activation output. dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass. x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
...@@ -979,21 +1086,25 @@ def quantize_dact_dbias( ...@@ -979,21 +1086,25 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the bias. - The gradient of the activation with respect to the bias.
""" """
act_len = len(activation_type)
assert x.shape[-2] == act_len, (
"activation input should be replicated by act_len in the -2 axis, got input shape"
f" {x.shape} and act_len {act_len}"
)
if not DActLuDBiasQuantizePrimitive.enabled(): if not DActLuDBiasQuantizePrimitive.enabled():
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support colwise-only quantization yet # TE/common does not support colwise-only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = quantize_dact_dbias( out = dact_lu(dz, x, activation_type, quantizer=None)
dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
)
return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
is_gated = len(activation_type) == 2 is_gated = act_len == 2
# TE/common does not support DelayedScaling2x for gated-act yet # TE/common does not support DelayedScaling2x for gated-act yet
if is_gated: if is_gated:
war_output = try_apply_delayed_scaling_2x_war( war_output = try_apply_delayed_scaling_2x_war(
...@@ -1003,6 +1114,7 @@ def quantize_dact_dbias( ...@@ -1003,6 +1114,7 @@ def quantize_dact_dbias(
activation_type=activation_type, activation_type=activation_type,
is_dbias=is_dbias, is_dbias=is_dbias,
quantizer=quantizer, quantizer=quantizer,
flatten_axis=-2,
) )
if war_output is not None: if war_output is not None:
return war_output return war_output
...@@ -1019,18 +1131,18 @@ def quantize_dact_dbias( ...@@ -1019,18 +1131,18 @@ def quantize_dact_dbias(
# outputs float32 for dbias accumulation # outputs float32 for dbias accumulation
out_dtype=(jnp.float32 if is_dbias else x.dtype), out_dtype=(jnp.float32 if is_dbias else x.dtype),
# default value for no scaling, TE/common ignore this value when scale is unset # default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused is_2x=False, # unused
scale_dtype=jnp.float32, # unused scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused scale_shapes=((), ()), # unused
is_dbias=False, is_dbias=False,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
is_outer=True, is_outer=True,
) )
dbias = None dbias = None
if is_dbias: if is_dbias:
dbias = _jax_dbias(output).astype(x.dtype) dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias return output.astype(x.dtype), dbias
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
...@@ -1041,16 +1153,9 @@ def quantize_dact_dbias( ...@@ -1041,16 +1153,9 @@ def quantize_dact_dbias(
dgated = dact_lu( dgated = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
) )
# TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests out, dbias = _quantize_dbias_impl(
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype) )
else:
out, dbias = quantize_dbias(
dgated,
quantizer=quantizer,
is_dbias=True,
dq_dtype=x.dtype,
)
return out, dbias return out, dbias
out_shape = x.shape out_shape = x.shape
...@@ -1070,15 +1175,16 @@ def quantize_dact_dbias( ...@@ -1070,15 +1175,16 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(), is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(out_shape), # output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias, is_dbias=is_dbias,
act_enum=act_type_id, act_enum=act_type_id,
act_len=len(activation_type), act_len=act_len,
is_outer=True, is_outer=True,
) )
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax) quantizer.update(updated_amax)
...@@ -1090,8 +1196,9 @@ def quantize_dact_dbias( ...@@ -1090,8 +1196,9 @@ def quantize_dact_dbias(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=-2, # as output has act axis
) )
return out, dbias return out, dbias
......
...@@ -14,6 +14,7 @@ import jax ...@@ -14,6 +14,7 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, lax from jax import dtypes, lax
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine_jax import NVTE_Fused_Attn_Backend
...@@ -42,6 +43,7 @@ from ..sharding import ( ...@@ -42,6 +43,7 @@ from ..sharding import (
get_mesh_axis_rank, get_mesh_axis_rank,
get_all_mesh_axes, get_all_mesh_axes,
num_of_devices, num_of_devices,
with_sharding_constraint,
) )
...@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive): ...@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
impl = partial(FusedAttnFwdPrimitive.impl, config=config) impl = partial(FusedAttnFwdPrimitive.impl, config=config)
return mesh, impl, out_shardings, arg_shardings return mesh, impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del mesh, result_types
# Keep in sync with `infer_sharding_from_operands`.
# We only need the first input. Fill up the rest with placeholders.
input_spec = [(f"…{x}",) for x in range(len(value_types))]
# The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
# instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
rng_sharding = (f"…{len(value_types)}",)
if config.qkv_layout.is_qkvpacked():
input_spec[0] = ("…0", "seqlen", "three", "head", "hidden")
elif config.qkv_layout.is_kvpacked() or config.qkv_layout.is_separate():
input_spec[0] = ("…0", "seqlen", "head", "hidden")
else:
raise ValueError(f"Unsupported {config.qkv_layout=}")
is_packed_softmax = get_cudnn_version() >= (9, 6, 0) and config.qkv_layout.is_thd()
out_sharding = ("…0", "seqlen", "head", "hidden")
if is_packed_softmax:
softmax_aux_sharding = ("…0", "seqlen", "head", "i")
else:
softmax_aux_sharding = ("…0", "head", "seqlen", "i")
return SdyShardingRule(
tuple(input_spec), (out_sharding, softmax_aux_sharding, rng_sharding)
)
register_primitive(FusedAttnFwdPrimitive) register_primitive(FusedAttnFwdPrimitive)
...@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive): ...@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(config, mesh, value_types, result_types):
del config, mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
input_spec = tuple((f"…{x}",) for x in range(len(value_types)))
output_spec = tuple((f"…{x}",) for x in range(len(result_types)))
return SdyShardingRule(input_spec, output_spec)
register_primitive(FusedAttnBwdPrimitive) register_primitive(FusedAttnBwdPrimitive)
...@@ -2436,13 +2476,15 @@ def fused_attn_fwd( ...@@ -2436,13 +2476,15 @@ def fused_attn_fwd(
primitive = FusedRingAttnFwdPrimitive.outer_primitive primitive = FusedRingAttnFwdPrimitive.outer_primitive
seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor)
return primitive.bind( output, softmax_aux, rng_state = primitive.bind(
*qkv_for_primitive, *qkv_for_primitive,
bias, bias,
seed, seed,
*seq_desc_flatten, *seq_desc_flatten,
config=fused_config, config=fused_config,
) )
rng_state = with_sharding_constraint(rng_state, PartitionSpec(get_all_mesh_axes(), None))
return (output, softmax_aux, rng_state)
def fused_attn_bwd( def fused_attn_bwd(
......
...@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta):
""" """
return NotImplemented return NotImplemented
@staticmethod
@abstractmethod
def shardy_sharding_rule(*args):
"""
Returns the sharding rule for this primitive.
"""
del args
return "... -> ..."
def register_primitive(cls): def register_primitive(cls):
""" """
...@@ -123,7 +132,9 @@ def register_primitive(cls): ...@@ -123,7 +132,9 @@ def register_primitive(cls):
batching.primitive_batchers[outer_p] = cls.batcher batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition( outer_p_lower.def_partition(
infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition infer_sharding_from_operands=cls.infer_sharding_from_operands,
partition=cls.partition,
sharding_rule=cls.shardy_sharding_rule,
) )
mlir.register_lowering( mlir.register_lowering(
outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
from typing import Tuple, Sequence, Union, Dict, List from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce from functools import partial, reduce
import operator import operator
from transformer_engine_jax import get_device_compute_capability
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
...@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive):
name = "te_grouped_gemm_ffi" name = "te_grouped_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = (6, 7, 8, 9) impl_static_args = ()
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
@staticmethod @staticmethod
def abstract( def abstract(*args, num_gemms, scaling_mode, out_dtype, has_bias):
lhs_contig_aval, """
lhs_scale_contig_aval, Args:
rhs_contig_aval, *args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
rhs_scale_contig_aval, args[ 0 : num_gemms] are the lhs tensors,
bias_contig_aval, args[ num_gemms : 2*num_gemms] are the rhs tensors,
dim_list_aval, args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
*, args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
num_gemms, args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
scaling_mode, num_gemms: Number of GEMM operations to perform.
out_dtype, scaling_mode: Scaling mode for the GEMM operations.
out_flat_size, out_dtype: Data type of the output tensors.
): has_bias: Boolean indicating if bias tensors are provided.
del lhs_contig_aval, lhs_scale_contig_aval
del rhs_contig_aval, rhs_scale_contig_aval Returns:
del bias_contig_aval, dim_list_aval A tuple of ShapedArray objects of size num_gemms+1:
del num_gemms, scaling_mode ret[0 : num_gemms]: GEMM output tensors,
out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype) ret[num_gemms]:workspace tensor.
wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams """
wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8) del scaling_mode
return (out_flat_aval, wkspace_aval) expected_num_args = 5 * num_gemms if has_bias else 4 * num_gemms
assert (
len(args) == expected_num_args
), f"Expected {expected_num_args} input arguments, but got {len(args)}"
A_list = args[0:num_gemms]
B_list = args[num_gemms : 2 * num_gemms]
# A and B have shapes [1, m, k] and [1, n, k]
out_list_aval = tuple(
jax.core.ShapedArray((A.shape[1], B.shape[1]), dtype=out_dtype)
for A, B in zip(A_list, B_list)
)
workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return (*out_list_aval, workspace_aval)
@staticmethod @staticmethod
def outer_abstract(*args, **kwargs): def outer_abstract(*args, **kwargs):
...@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive): ...@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive):
return out_aval return out_aval
@staticmethod @staticmethod
def lowering( def lowering(ctx, *args, num_gemms, scaling_mode, out_dtype, has_bias):
ctx, del out_dtype
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
*,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
del out_dtype, out_flat_size
return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
ctx, ctx,
lhs_contig, *args,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms, num_gemms=num_gemms,
scaling_mode=int(scaling_mode), scaling_mode=int(scaling_mode),
has_bias=has_bias,
) )
@staticmethod @staticmethod
def impl( def impl(*args, num_gemms, scaling_mode, out_dtype, has_bias):
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms,
scaling_mode,
out_dtype,
out_flat_size,
) -> jnp.ndarray:
assert GroupedGemmPrimitive.inner_primitive is not None assert GroupedGemmPrimitive.inner_primitive is not None
out = GroupedGemmPrimitive.inner_primitive.bind( out = GroupedGemmPrimitive.inner_primitive.bind(
lhs_contig, *args,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms, num_gemms=num_gemms,
scaling_mode=scaling_mode.value, scaling_mode=scaling_mode.value,
out_dtype=out_dtype, out_dtype=out_dtype,
out_flat_size=out_flat_size, has_bias=has_bias,
) )
return out[0] # out is [out_flat, wkspace], only return out_flat return out[:-1] # out is [out_list, wkspace], only return out_list
register_primitive(GroupedGemmPrimitive) register_primitive(GroupedGemmPrimitive)
...@@ -183,10 +163,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision): ...@@ -183,10 +163,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose # Reshape + Transpose
# [..., M, K] -> [B, M, K] # [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K] # [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N") lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T") rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.data_layout == "T")
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums = (((2,), (2,)), ((0,), (0,))) dim_nums = (((2,), (2,)), ((0,), (0,)))
out_3d = jax.lax.dot_general( out_3d = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
...@@ -199,13 +178,13 @@ def _jax_gemm_delayed_scaling_fp8( ...@@ -199,13 +178,13 @@ def _jax_gemm_delayed_scaling_fp8(
): ):
"""FP8 GEMM for XLA pattern match""" """FP8 GEMM for XLA pattern match"""
assert ( assert (
rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode" ), "rhs does not have delayed tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.layout == "T": if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract) lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
if rhs.layout == "T": if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract) rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)
lhs_dn = (lhs_contract, lhs_batch) lhs_dn = (lhs_contract, lhs_batch)
...@@ -231,7 +210,7 @@ def _jax_gemm_mxfp8_1d( ...@@ -231,7 +210,7 @@ def _jax_gemm_mxfp8_1d(
JAX GEMM for MXFP8 via scaled_matmul JAX GEMM for MXFP8 via scaled_matmul
""" """
assert ( assert (
rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode" ), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper
...@@ -292,10 +271,10 @@ def _jax_gemm( ...@@ -292,10 +271,10 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs): def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums) return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")
...@@ -367,6 +346,7 @@ def swizzled_scale(scales): ...@@ -367,6 +346,7 @@ def swizzled_scale(scales):
rows, cols = scales.shape rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
scales = jnp.transpose(scales, (0, 3, 2, 1, 4)) scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
scales = scales.reshape(rows, cols)
return scales return scales
...@@ -381,18 +361,12 @@ def grouped_gemm( ...@@ -381,18 +361,12 @@ def grouped_gemm(
len(lhs_list) == len(rhs_list) == len(contracting_dims_list) len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length" ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
# Flatten inputs and save their shapes
num_gemms = len(lhs_list)
out_flat_size = 0
dims = []
lhs_contig_ = []
rhs_contig_ = []
lhs_scale_inv_contig_ = []
rhs_scale_inv_contig_ = []
bias_contig_ = []
out_offsets = []
remain_shape_list = []
num_gemms = len(lhs_list) num_gemms = len(lhs_list)
lhs_list_ = []
rhs_list_ = []
lhs_sinv_list_ = []
rhs_sinv_list_ = []
bias_list_ = []
for i in range(num_gemms): for i in range(num_gemms):
lhs = lhs_list[i] lhs = lhs_list[i]
rhs = rhs_list[i] rhs = rhs_list[i]
...@@ -403,20 +377,20 @@ def grouped_gemm( ...@@ -403,20 +377,20 @@ def grouped_gemm(
lhs_shape = lhs.data.shape lhs_shape = lhs.data.shape
rhs_shape = rhs.data.shape rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype out_dtype = lhs.dq_dtype
# For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout # For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
assert not ( assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2 lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2" ), "FP8 GEMM does not support E5M2 * E5M2"
((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
if lhs.layout == "T": if lhs.data_layout == "T":
lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
if rhs.layout == "T": if rhs.data_layout == "T":
rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ()) dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
else: else:
# For jnp.ndarray, only consider contracting_dims, layout is always NN # For jnp.ndarray, only consider contracting_dims, data_layout is always NN
scaling_mode = ScalingMode.NVTE_NO_SCALING scaling_mode = ScalingMode.NO_SCALING
lhs_shape = lhs.shape lhs_shape = lhs.shape
rhs_shape = rhs.shape rhs_shape = rhs.shape
out_dtype = lhs.dtype out_dtype = lhs.dtype
...@@ -428,24 +402,25 @@ def grouped_gemm( ...@@ -428,24 +402,25 @@ def grouped_gemm(
lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract) lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract) rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)
if scaling_mode == ScalingMode.NVTE_NO_SCALING: # Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn) lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn) rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N") lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T") rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_3d = _shape_normalization(lhs.data, lhs_dn) lhs_3d = _shape_normalization(lhs.data, lhs_dn)
rhs_3d = _shape_normalization(rhs.data, rhs_dn) rhs_3d = _shape_normalization(rhs.data, rhs_dn)
lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn) lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn) rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
# swizzled_scale requires a matrix
lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze()) lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze()) rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
else: else:
raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}") raise NotImplementedError("Unsupported ScalingMode: {scaling_mode}")
# Note: if _shape_normalization() is updated to support non-TN, need to update here # Note: already_transposed doesn't matter for the output shape
# already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2] # x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2] # contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1] # contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
...@@ -456,61 +431,37 @@ def grouped_gemm( ...@@ -456,61 +431,37 @@ def grouped_gemm(
bn = rhs_remain_shape[0] bn = rhs_remain_shape[0]
kl = lhs_3d.shape[-1] kl = lhs_3d.shape[-1]
kr = rhs_3d.shape[-1] kr = rhs_3d.shape[-1]
remain_shape_list.append(((bm,), (bn,))) assert kl == kr, f"After shape normalization, contracting dim size mismatch: {kl} != {kr}"
assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})" if (bm % 16 != 0) or (bn % 16 != 0) or (kl % 16 != 0):
k = kl print("grouped_gemm input pair {i} has invalid problem shape for lowering: ")
print(f"m = {bm}, n = {bn}, k = {kl}; ")
if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0): print("cuBLAS requires the problem shapes being multiples of 16")
print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ") assert (bm % 16 == 0) and (bn % 16 == 0) and (kl % 16 == 0)
print(
f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples" lhs_list_.append(lhs_3d)
" of 16" rhs_list_.append(rhs_3d)
) if scaling_mode == ScalingMode.NO_SCALING:
assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0 lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
dims.append((bm, bn, k)) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
lhs_contig_.append(lhs_3d.reshape(-1)) lhs_sinv_list_.append(lhs.scale_inv)
rhs_contig_.append(rhs_3d.reshape(-1)) rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.NVTE_NO_SCALING: if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) lhs_sinv_list_.append(lhs_scale_inv)
rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32)) rhs_sinv_list_.append(rhs_scale_inv)
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
if bias_list is not None: if bias_list is not None:
bias_contig_.append(bias_list[i].reshape(-1)) bias_list_.append(bias_list[i])
out_flat_size += bm * bn
out_offsets.append(out_flat_size) out_list = GroupedGemmPrimitive.outer_primitive.bind(
*lhs_list_,
lhs_contig = jnp.concatenate(lhs_contig_) *rhs_list_,
rhs_contig = jnp.concatenate(rhs_contig_) *lhs_sinv_list_,
lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_) *rhs_sinv_list_,
rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_) *bias_list_,
bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
dim_list = jnp.array(dims, dtype=jnp.int32)
# Perform batched GEMM on flattened inputs
out_contig = GroupedGemmPrimitive.outer_primitive.bind(
lhs_contig,
lhs_scale_inv_contig,
rhs_contig,
rhs_scale_inv_contig,
bias_contig,
dim_list,
num_gemms=num_gemms, num_gemms=num_gemms,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
out_dtype=out_dtype, out_dtype=out_dtype,
out_flat_size=out_flat_size, has_bias=1 if bias_list is not None else 0,
) )
# Split the output back into tensors return out_list
out_offsets = jnp.array(out_offsets)
out_flat_list = jnp.split(out_contig, out_offsets[:-1])
out_tensors = []
for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
return out_tensors
...@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type ...@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType TEDType = transformer_engine_jax.DType
...@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim): ...@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return axis if axis >= 0 else ndim + axis return axis if axis >= 0 else ndim + axis
def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1): def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis=-1):
""" """
te_cast_transpose_p multi-dims transpose te_cast_transpose_p multi-dims transpose
static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
involved into transpose, -1 means all axes involve into transpose. involved into transpose, -1 means all axes involve into transpose.
transpose_axis_boundary: int, Indicate how to split multi-dimensions tensors to 2D matrix for transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis_boundary should be greater than static_axis_boundary transpose. Note, transpose_axis should be greater than static_axis_boundary
examples: examples:
X in shape (dim0, dim1, dim2, dim3, dim4) X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis_boundary == 2 static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1) Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 2 static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1) Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis_boundary == 3 static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1. dim2) Xt = (dim0, dim3, dim4, dim1. dim2)
""" """
if static_axis_boundary < 0: if static_axis_boundary < 0:
static_axis_boundary = -1 # means no static axes static_axis_boundary = -1 # means no static axes
assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose. assert static_axis_boundary < len(shape) - 2 # at least 2 remaining for transpose.
transpose_start_idx = static_axis_boundary + 1 transpose_start_idx = static_axis_boundary + 1
transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, len(shape)) transpose_axis = normalize_axis_boundary(transpose_axis, len(shape))
assert transpose_start_idx < transpose_axis_boundary assert transpose_start_idx < transpose_axis
return ( return (
*shape[:transpose_start_idx], *shape[:transpose_start_idx],
*shape[transpose_axis_boundary:], *shape[transpose_axis:],
*shape[transpose_start_idx:transpose_axis_boundary], *shape[transpose_start_idx:transpose_axis],
) )
...@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant ...@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break break
return ( return (
quantizer is not None quantizer is not None
and quantizer.q_axis == QuantizeAxis.ROWWISE and quantizer.q_layout == QuantizeLayout.ROWWISE
and arch_l_100 and arch_l_100
and is_dbias and is_dbias
) )
def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, **kwargs):
""" """
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling. Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result. It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
...@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
""" """
should_apply_war = ( should_apply_war = (
quantizer is not None quantizer is not None
and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
and quantizer.is_2x2x() and quantizer.is_2x2x()
) )
if not should_apply_war: if not should_apply_war:
...@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling # 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX # so revert to 1x and transpose in JAX
quantizer.q_axis = QuantizeAxis.ROWWISE quantizer.q_layout = QuantizeLayout.ROWWISE
rowwise = f(*args, **kwargs, quantizer=quantizer) rowwise = f(*args, **kwargs, quantizer=quantizer)
other_outputs = None other_outputs = None
if isinstance(rowwise, tuple): if isinstance(rowwise, tuple):
other_outputs = rowwise[1:] other_outputs = rowwise[1:]
rowwise = rowwise[0] rowwise = rowwise[0]
quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE quantizer.q_layout = QuantizeLayout.ROWWISE_COLWISE
colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1))) if flatten_axis < 0:
flatten_axis += rowwise.data.ndim
assert 0 < flatten_axis < rowwise.data.ndim, "flatten_axis is out of bounds"
colwise_data = jnp.transpose(
rowwise.data, (*range(flatten_axis, rowwise.data.ndim), *range(flatten_axis))
)
output_2x = ScaledTensorFactory.create( output_2x = ScaledTensorFactory.create(
data=rowwise.data, data=rowwise.data,
scale_inv=rowwise.scale_inv, scale_inv=rowwise.scale_inv,
...@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs): ...@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv=rowwise.scale_inv, colwise_scale_inv=rowwise.scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=rowwise.dq_dtype, dq_dtype=rowwise.dq_dtype,
q_axis=QuantizeAxis.ROWWISE_COLWISE, q_layout=QuantizeLayout.ROWWISE_COLWISE,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
) )
if other_outputs is not None: if other_outputs is not None:
return (output_2x,) + other_outputs return (output_2x,) + other_outputs
......
...@@ -12,6 +12,7 @@ from packaging import version ...@@ -12,6 +12,7 @@ from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -30,7 +31,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a ...@@ -30,7 +31,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeAxis, QuantizeLayout,
DelayedScaleQuantizer, DelayedScaleQuantizer,
ScalingMode, ScalingMode,
) )
...@@ -63,6 +64,27 @@ def get_backward_sm_margin(): ...@@ -63,6 +64,27 @@ def get_backward_sm_margin():
return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
@cache
def is_norm_fwd_cudnn_enabled(scaling_mode: ScalingMode) -> bool:
"""Retrieves whether CuDNN norm fwd is enabled."""
# MXFP8_1D_SCALING always uses CuDNN currently
return (
int(os.getenv("NVTE_NORM_FWD_USE_CUDNN", "0")) == 1
or scaling_mode == ScalingMode.MXFP8_1D_SCALING
)
@cache
def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bool:
"""Retrieves whether norm should compute `gamma += 1.0` for zero-centered gamma
in weight dtype as opposed to compute dtype."""
if not is_norm_fwd_cudnn_enabled(scaling_mode):
# If CuDNN is not enabled, we use the TE backend which uses the compute dtype not weight dtype
# Remove this when TE supports gamma += 1.0 in weight dtype
return False
return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1
class NormFwdPrimitive(BasePrimitive): class NormFwdPrimitive(BasePrimitive):
""" """
Layer Normalization Forward FP8 Primitive Layer Normalization Forward FP8 Primitive
...@@ -105,6 +127,26 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -105,6 +127,26 @@ class NormFwdPrimitive(BasePrimitive):
if norm_type == NVTE_Norm_Type.LayerNorm: if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size assert gamma_aval.size == beta_aval.size
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
(wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
x_aval.size // gamma_aval.size, # batch size x_aval.size // gamma_aval.size, # batch size
gamma_aval.size, # hidden size gamma_aval.size, # hidden size
...@@ -112,33 +154,13 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -112,33 +154,13 @@ class NormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype jax_dtype_to_te_dtype(gamma_aval.dtype), # wtype
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
norm_type, norm_type,
scaling_mode.value, scaling_mode,
zero_centered_gamma, zero_centered_gamma,
epsilon, epsilon,
get_forward_sm_margin(), get_forward_sm_margin(),
is_2x, is_2x,
) )
wkspace_aval = jax.core.ShapedArray(
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
if norm_type == NVTE_Norm_Type.RMSNorm:
mu_aval = mu_aval.update(shape=(1,))
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
x_aval.shape, is_padded=not is_outer
)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
colwise_out_aval = jax.core.ShapedArray(
shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = x_aval.update(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
) )
...@@ -274,17 +296,17 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -274,17 +296,17 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_outer=False, is_outer=False,
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
x.shape, is_padded=False scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False)
# slice out padding for mxfp8, noop for DelayedScaling
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if is_2x:
scale_inv = scale_inv.flatten()[ colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, rowwise_scale_inv_shape) : reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(rowwise_scale_inv_shape) ].reshape(colwise_scale_inv_shape)
if is_2x:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape)
].reshape(colwise_scale_inv_shape)
return ( return (
out, out,
colwise_out, colwise_out,
...@@ -364,6 +386,8 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -364,6 +386,8 @@ class NormFwdPrimitive(BasePrimitive):
del zero_centered_gamma, epsilon, out_dtype, result_infos del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer del scale_dtype, scale_shapes, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
...@@ -371,34 +395,27 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -371,34 +395,27 @@ class NormFwdPrimitive(BasePrimitive):
"and hurt performance." "and hurt performance."
) )
out_sharding = NamedSharding( out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out" colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
) )
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding( rsigma_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma" mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
) )
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
if norm_type == NVTE_Norm_Type.RMSNorm: mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_inv_spec = amax_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding( scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv" mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
output = ( output = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -427,8 +444,11 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -427,8 +444,11 @@ class NormFwdPrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
g_spec = get_padded_spec(arg_infos[2]) g_spec = get_padded_spec(arg_infos[2])
b_spec = get_padded_spec(arg_infos[3]) b_spec = get_padded_spec(arg_infos[3])
out_spec = (*x_spec[:-1], None)
if x_spec[-1] is not None: if x_spec[-1] is not None:
warnings.warn( warnings.warn(
f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! " f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
...@@ -445,43 +465,30 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -445,43 +465,30 @@ class NormFwdPrimitive(BasePrimitive):
f"{NormFwdPrimitive.name} does not support sharding of parameter beta " f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! " "Enforcing no sharding of parameters hidden dim! "
) )
x_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
)
g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
if is_2x:
colwise_out_sharding = out_sharding.duplicate_with_new_description(
"NormFwdPrimitive.colwise_out"
)
else:
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
rsigma_sharding = NamedSharding( rsigma_sharding = NamedSharding(
mesh, mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
desc="NormFwdPrimitive.rsigma",
) )
mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu") mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,)
if norm_type == NVTE_Norm_Type.RMSNorm: mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu")
mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
scale_sharding = NamedSharding( scale_inv_spec = amax_spec = (None,)
mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale" if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
) scale_inv_spec = amax_spec = scale_spec
scale_inv_sharding = scale_sharding.duplicate_with_new_description( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
"NormFwdPrimitive.scale_inv" scale_inv_spec = out_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="NormFwdPrimitive.scale_inv"
) )
amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax") amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
)
arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -517,7 +524,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -517,7 +524,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_outer=True, is_outer=True,
) )
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -534,6 +541,57 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -534,6 +541,57 @@ class NormFwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
norm_type,
zero_centered_gamma,
epsilon,
out_dtype,
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
value_types,
result_types,
):
del (
zero_centered_gamma,
epsilon,
out_dtype,
scale_dtype,
scale_shapes,
is_outer,
mesh,
result_types,
)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=-1
)
x_axes = scale_rules.input_spec
out = x_axes[:-1] + ("k",)
colwise_out = out if is_2x else ("…4",)
rsigma = x_axes[:-1]
mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = ("…6",)
return SdyShardingRule(
(x_axes, ("…1",), ("…2",), ("…3",)),
(
out,
colwise_out,
scale_rules.rowwise_rule,
scale_rules.colwise_rule,
amax,
mu,
rsigma,
),
**scale_rules.factor_sizes,
)
register_primitive(NormFwdPrimitive) register_primitive(NormFwdPrimitive)
...@@ -737,6 +795,11 @@ class NormBwdPrimitive(BasePrimitive): ...@@ -737,6 +795,11 @@ class NormBwdPrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(*args):
del args
return "...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l"
register_primitive(NormBwdPrimitive) register_primitive(NormBwdPrimitive)
...@@ -746,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) ...@@ -746,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
JAX native layernorm implementation JAX native layernorm implementation
""" """
x_ = jnp.asarray(x, jnp.float32) x_ = jnp.asarray(x, jnp.float32)
if not is_norm_zero_centered_gamma_in_weight_dtype(
quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
):
gamma = gamma.astype(jnp.float32)
mean = jnp.mean(x_, axis=-1, keepdims=True) mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon) rsigma = jax.lax.rsqrt(var + epsilon)
...@@ -767,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): ...@@ -767,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
JAX native rmsnorm implementation JAX native rmsnorm implementation
""" """
x_ = jnp.asarray(x, jnp.float32) x_ = jnp.asarray(x, jnp.float32)
if not is_norm_zero_centered_gamma_in_weight_dtype(
quantizer.scaling_mode if quantizer else ScalingMode.NO_SCALING
):
gamma = gamma.astype(jnp.float32)
var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True) var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
rsigma = jax.lax.rsqrt(var + epsilon) rsigma = jax.lax.rsqrt(var + epsilon)
normed_input = x_ * rsigma normed_input = x_ * rsigma
...@@ -816,7 +887,7 @@ def layernorm_fwd( ...@@ -816,7 +887,7 @@ def layernorm_fwd(
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -824,7 +895,6 @@ def layernorm_fwd( ...@@ -824,7 +895,6 @@ def layernorm_fwd(
if isinstance(quantizer, DelayedScaleQuantizer) if isinstance(quantizer, DelayedScaleQuantizer)
else jnp.ones((1,), dtype=jnp.float32) else jnp.ones((1,), dtype=jnp.float32)
) )
if quantizer is None: if quantizer is None:
output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
x, x,
...@@ -835,7 +905,7 @@ def layernorm_fwd( ...@@ -835,7 +905,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)), scale_shapes=((1,), (1,)),
...@@ -845,7 +915,7 @@ def layernorm_fwd( ...@@ -845,7 +915,7 @@ def layernorm_fwd(
is_2x2x = quantizer.is_2x2x() is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling # TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False is_2x2x = False
( (
rowwise_casted_output, rowwise_casted_output,
...@@ -864,7 +934,7 @@ def layernorm_fwd( ...@@ -864,7 +934,7 @@ def layernorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape), scale_shapes=quantizer.get_scale_shapes(x.shape),
...@@ -873,7 +943,7 @@ def layernorm_fwd( ...@@ -873,7 +943,7 @@ def layernorm_fwd(
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -882,7 +952,7 @@ def layernorm_fwd( ...@@ -882,7 +952,7 @@ def layernorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False x.shape, is_padded=False
) )
...@@ -900,8 +970,8 @@ def layernorm_fwd( ...@@ -900,8 +970,8 @@ def layernorm_fwd(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
return scaled_tensor, mu, rsigma return scaled_tensor, mu, rsigma
...@@ -997,7 +1067,7 @@ def rmsnorm_fwd( ...@@ -997,7 +1067,7 @@ def rmsnorm_fwd(
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
# TE/common does not support normalization with colwise only quantization yet # TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
scale = ( scale = (
...@@ -1017,7 +1087,7 @@ def rmsnorm_fwd( ...@@ -1017,7 +1087,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=x.dtype, out_dtype=x.dtype,
scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, is_2x=False,
scale_dtype=jnp.float32, scale_dtype=jnp.float32,
scale_shapes=((), ()), scale_shapes=((), ()),
...@@ -1027,7 +1097,7 @@ def rmsnorm_fwd( ...@@ -1027,7 +1097,7 @@ def rmsnorm_fwd(
is_2x2x = quantizer.is_2x2x() is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling # TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
is_2x2x = False is_2x2x = False
( (
rowwise_casted_output, rowwise_casted_output,
...@@ -1046,7 +1116,7 @@ def rmsnorm_fwd( ...@@ -1046,7 +1116,7 @@ def rmsnorm_fwd(
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon, epsilon=epsilon,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x, is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape), scale_shapes=quantizer.get_scale_shapes(x.shape),
...@@ -1055,7 +1125,7 @@ def rmsnorm_fwd( ...@@ -1055,7 +1125,7 @@ def rmsnorm_fwd(
quantizer.update(updated_amax) quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING: if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
colwise_casted_output = jnp.transpose( colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
) )
...@@ -1064,7 +1134,7 @@ def rmsnorm_fwd( ...@@ -1064,7 +1134,7 @@ def rmsnorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape. # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING: if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
x.shape, is_padded=False x.shape, is_padded=False
) )
...@@ -1082,8 +1152,8 @@ def rmsnorm_fwd( ...@@ -1082,8 +1152,8 @@ def rmsnorm_fwd(
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=x.dtype, dq_dtype=x.dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
) )
return scaled_tensor, rsigma return scaled_tensor, rsigma
......
...@@ -2,12 +2,15 @@ ...@@ -2,12 +2,15 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for quantization""" """JAX/TE custom ops for quantization"""
import operator
from functools import reduce
from typing import Tuple, Optional from typing import Tuple, Optional
from packaging import version from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import transformer_engine_jax import transformer_engine_jax
...@@ -24,7 +27,7 @@ from .misc import ( ...@@ -24,7 +27,7 @@ from .misc import (
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode
if version.parse(jax.__version__) >= version.parse("0.5.0"): if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports from jax import ffi # pylint: disable=ungrouped-imports
...@@ -50,7 +53,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -50,7 +53,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
6, 6,
7, 7,
8, 8,
) # out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer 9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -61,7 +65,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -61,7 +65,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -73,49 +78,56 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -73,49 +78,56 @@ class DBiasQuantizePrimitive(BasePrimitive):
del scale_shapes del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype) dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
assert scale_aval is None or scale_aval.dtype == jnp.float32 assert scale_aval is None or scale_aval.dtype == jnp.float32
rowwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
rowwise_out_shape = out_shape
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): else:
rowwise_out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype) rowwise_out_shape = (1,)
rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype)
updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
else:
colwise_out_shape = (1,)
colwise_scale_inv_shape = (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)
scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype) shape=colwise_scale_inv_shape, dtype=scale_dtype
colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype) )
dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
t_shape = multidim_transpose(x_aval.shape)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
# Don't transpose output for MXFP8
t_shape = x_aval.shape
colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
if is_dbias: if is_dbias:
gi_hidden_size = x_aval.shape[-1] dbias_shape = x_aval.shape[flatten_axis:]
dbias_shape = (gi_hidden_size,) gi_hidden_size = reduce(operator.mul, x_aval.shape[flatten_axis:], 1)
dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
(wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes( (wkspace_info,) = transformer_engine_jax.get_dbias_quantize_workspace_sizes(
x_aval.size // gi_hidden_size, x_aval.size // gi_hidden_size,
gi_hidden_size, gi_hidden_size,
jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(x_aval.dtype),
jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(out_dtype),
scaling_mode,
QuantizeLayout(
q_layout
), # For now until we have auto-decoding for QuantizeLayout enum
) )
wkspace_aval = x_aval.update( wkspace_shape = wkspace_info[0]
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1])
) else:
dbias_shape = (1,)
wkspace_shape = (1,)
wkspace_dtype = jnp.float32
dbias_aval = jax.core.ShapedArray(shape=dbias_shape, dtype=dtype)
wkspace_aval = jax.core.ShapedArray(shape=wkspace_shape, dtype=wkspace_dtype)
return ( return (
rowwise_out_aval, rowwise_out_aval,
...@@ -151,7 +163,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -151,7 +163,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -168,8 +181,9 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -168,8 +181,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
ctx, ctx,
x, x,
scale, scale,
scaling_mode=scaling_mode, scaling_mode=scaling_mode.value,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
is_dbias=is_dbias, is_dbias=is_dbias,
) )
...@@ -179,7 +193,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -179,7 +193,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -203,7 +218,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -203,7 +218,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -211,16 +227,14 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -211,16 +227,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
) )
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode scaling_mode
).get_scale_shape_2x(x.shape, is_padded=False) ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis)
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: scale_inv = jax.lax.slice(
if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
scale_inv = jax.lax.slice( )
scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
) colwise_scale_inv = jax.lax.slice(
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
colwise_scale_inv = jax.lax.slice( )
colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
)
return ( return (
out, out,
colwise_out, colwise_out,
...@@ -237,7 +251,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -237,7 +251,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*, *,
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -260,7 +275,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -260,7 +275,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
...@@ -272,7 +288,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -272,7 +288,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def infer_sharding_from_operands( def infer_sharding_from_operands(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -281,16 +298,17 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -281,16 +298,17 @@ class DBiasQuantizePrimitive(BasePrimitive):
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer) # Unused. del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]), PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
else: else:
...@@ -300,26 +318,35 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -300,26 +318,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec), PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding", desc="DBiasQuantizePrimitive.colwise_out_sharding",
) )
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])), PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.scale_inv", desc="DBiasQuantizePrimitive.dbias_sharding",
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding" scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: amax_sharding = NamedSharding(
scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
) )
dbias_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.dbias_sharding", desc="DBiasQuantizePrimitive.colwise_scale_inv",
) )
return ( return (
out_sharding, out_sharding,
colwise_out_sharding, colwise_out_sharding,
...@@ -333,7 +360,8 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -333,7 +360,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def partition( def partition(
out_dtype, out_dtype,
scaling_mode, scaling_mode,
q_axis, q_layout,
flatten_axis,
scale_dtype, scale_dtype,
scale_shapes, scale_shapes,
is_dbias, is_dbias,
...@@ -344,14 +372,15 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -344,14 +372,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
): ):
del result_infos, is_outer del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0]) x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding( out_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*x_spec[:-1], x_spec[-1]), PartitionSpec(*x_spec),
desc="DBiasQuantizePrimitive.out_sharding", desc="DBiasQuantizePrimitive.out_sharding",
) )
if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value): if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out_spec = multidim_transpose(x_spec) colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else: else:
colwise_out_spec = x_spec colwise_out_spec = x_spec
else: else:
...@@ -361,26 +390,35 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -361,26 +390,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec(*colwise_out_spec), PartitionSpec(*colwise_out_spec),
desc="DBiasQuantizePrimitive.colwise_out_sharding", desc="DBiasQuantizePrimitive.colwise_out_sharding",
) )
scale_inv_sharding = NamedSharding(
dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,)
dbias_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(*get_padded_spec(arg_infos[1])), PartitionSpec(*dbias_spec),
desc="DBiasQuantizePrimitive.scale_inv", desc="DBiasQuantizePrimitive.dbias_sharding",
) )
amax_sharding = scale_inv_sharding.duplicate_with_new_description(
desc="DBiasQuantizePrimitive.amax_sharding" scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
scale_inv_spec = amax_spec = scale_spec
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
scale_inv_spec = x_spec
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
colwise_scale_inv_spec = scale_inv_spec
scale_inv_sharding = NamedSharding(
mesh, PartitionSpec(*scale_inv_spec), desc="DBiasQuantizePrimitive.scale_inv"
) )
if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value: amax_sharding = NamedSharding(
scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*amax_spec), desc="DBiasQuantizePrimitive.amax"
mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
"DBiasQuantizePrimitive.colwise_scale_inv"
) )
dbias_sharding = NamedSharding( colwise_scale_inv_sharding = NamedSharding(
mesh, mesh,
PartitionSpec(x_spec[-1]), PartitionSpec(*colwise_scale_inv_spec),
desc="DBiasQuantizePrimitive.dbias_sharding", desc="DBiasQuantizePrimitive.colwise_scale_inv",
) )
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
out_shardings = ( out_shardings = (
out_sharding, out_sharding,
...@@ -404,14 +442,15 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -404,14 +442,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale, scale,
out_dtype=out_dtype, out_dtype=out_dtype,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
q_axis=q_axis, q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype, scale_dtype=scale_dtype,
scale_shapes=scale_shapes, scale_shapes=scale_shapes,
is_dbias=is_dbias, is_dbias=is_dbias,
is_outer=True, is_outer=True,
) )
if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
else: else:
global_updated_amax = local_amax global_updated_amax = local_amax
...@@ -432,53 +471,91 @@ class DBiasQuantizePrimitive(BasePrimitive): ...@@ -432,53 +471,91 @@ class DBiasQuantizePrimitive(BasePrimitive):
return mesh, sharded_impl, out_shardings, arg_shardings return mesh, sharded_impl, out_shardings, arg_shardings
@staticmethod
def shardy_sharding_rule(
out_dtype,
scaling_mode,
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
)
x_axes = scale_rules.input_spec
colwise_scale_inv = scale_rules.colwise_rule
out = x_axes
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else:
colwise_out = x_axes
else:
colwise_out = ("j",)
colwise_scale_inv = ("k",)
dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
amax = ("m",)
return SdyShardingRule(
(x_axes, ("…1",)),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
)
register_primitive(DBiasQuantizePrimitive) register_primitive(DBiasQuantizePrimitive)
def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None): def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
):
if quantizer is None: if quantizer is None:
return x return x
return quantizer.quantize(x, dq_dtype=dq_dtype) return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
def _jax_dbias(dx: jnp.ndarray): def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
assert flatten_axis < 0
dtype = dtype or dx.dtype
dbias = jnp.sum( dbias = jnp.sum(
dx, dx.astype(jnp.float32),
axis=tuple(range(dx.ndim - 1)), axis=tuple(range(dx.ndim + flatten_axis)),
keepdims=False, keepdims=False,
) )
dbias = dbias.ravel() # C++ function returns an 1D array for dbias return dbias.astype(dtype)
return dbias
def _jax_quantize_dbias( def _jax_quantize_dbias(
x, x,
quantizer: Quantizer = None, quantizer: Quantizer = None,
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
): ):
if quantizer is None: if quantizer is None:
return x, None return x, None
return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x) return (
quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
_jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis),
def _jax_dbias(
dx: jnp.ndarray,
):
dbias = jnp.sum(
dx.astype(jnp.float32),
axis=tuple(range(dx.ndim - 1)),
keepdims=False,
) )
dbias = dbias.ravel() # C++ function returns an 1D array for dbias
return dbias.astype(dx.dtype)
def _quantize_impl( def _quantize_dbias_impl(
x: jnp.ndarray, x: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = False, is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
...@@ -488,40 +565,51 @@ def _quantize_impl( ...@@ -488,40 +565,51 @@ def _quantize_impl(
quantizer is not None quantizer is not None
), "quantizer must be provided if dq_dtype is provided" ), "quantizer must be provided if dq_dtype is provided"
dq_dtype = dq_dtype or x.dtype
if not DBiasQuantizePrimitive.enabled(): if not DBiasQuantizePrimitive.enabled():
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
# TE/common doesn't support colwise only quantization yet # TE/common doesn't support colwise only quantization yet
if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE: if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if is_dbias: if is_dbias:
return _jax_quantize_dbias( return _jax_quantize_dbias(
x, x,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None return (
_jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
None,
)
scale = jnp.empty((), jnp.float32) scale = jnp.empty((), jnp.float32)
# TE/common dbias_quantize does not support 1x on arch < 100 # TE/common dbias_quantize does not support 1x on arch < 100
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out, _ = _quantize_impl( out, _ = _quantize_dbias_impl(
x=x, x=x,
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, dq_dtype=dq_dtype,
flatten_axis=flatten_axis,
) )
dbias = _jax_dbias(x) dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return out, dbias return out, dbias
if quantizer is None: if quantizer is None:
if is_dbias: if is_dbias:
return x, _jax_dbias(x) return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None return x, None
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
...@@ -539,14 +627,15 @@ def _quantize_impl( ...@@ -539,14 +627,15 @@ def _quantize_impl(
scale, scale,
out_dtype=quantizer.q_dtype, out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value, scaling_mode=quantizer.scaling_mode.value,
q_axis=quantizer.q_axis.value, q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(), scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape), scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias, is_dbias=is_dbias,
is_outer=True, is_outer=True,
) )
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x(): if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax) quantizer.update(updated_amax)
...@@ -557,18 +646,18 @@ def _quantize_impl( ...@@ -557,18 +646,18 @@ def _quantize_impl(
colwise_data=colwise_casted_output, colwise_data=colwise_casted_output,
colwise_scale_inv=colwise_scale_inv, colwise_scale_inv=colwise_scale_inv,
scaling_mode=quantizer.scaling_mode, scaling_mode=quantizer.scaling_mode,
dq_dtype=dq_dtype if dq_dtype is not None else x.dtype, dq_dtype=dq_dtype,
q_axis=quantizer.q_axis, q_layout=quantizer.q_layout,
layout=quantizer.get_layout(), data_layout=quantizer.get_data_layout(),
flatten_axis=flatten_axis,
) )
return out, dbias return out, dbias.astype(dq_dtype)
# TODO(Phuong): do not expose dq_dtype to users
def quantize( def quantize(
x: jnp.ndarray, x: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
...@@ -576,26 +665,25 @@ def quantize( ...@@ -576,26 +665,25 @@ def quantize(
x: Input tensor to be quantized. x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size. Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output. quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype for dequantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
If None, uses the same dtype as the input tensor. Defaults to -1.
Returns: Returns:
A ScaledTensor containing the quantized input tensor. A ScaledTensor containing the quantized input tensor.
""" """
out, _ = _quantize_impl( out, _ = _quantize_dbias_impl(
x, x,
quantizer=quantizer, quantizer=quantizer,
dq_dtype=dq_dtype, flatten_axis=flatten_axis,
) )
return out return out
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias( def quantize_dbias(
dz: jnp.ndarray, dz: jnp.ndarray,
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = True, is_dbias: bool = True,
dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
...@@ -604,8 +692,8 @@ def quantize_dbias( ...@@ -604,8 +692,8 @@ def quantize_dbias(
Shape: (..., K) where K is the hidden size. Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output. quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype for dequantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
If None, uses the same dtype as the input tensor. Defaults to -1.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -614,9 +702,6 @@ def quantize_dbias( ...@@ -614,9 +702,6 @@ def quantize_dbias(
- The bias gradient tensor. - The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False. Shape: (K,) or empty if is_dbias is False.
""" """
return _quantize_impl( return _quantize_dbias_impl(
dz, dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
quantizer=quantizer,
is_dbias=is_dbias,
dq_dtype=dq_dtype,
) )
...@@ -31,6 +31,9 @@ __all__ = [ ...@@ -31,6 +31,9 @@ __all__ = [
"scaled_upper_triang_masked_softmax_fwd", "scaled_upper_triang_masked_softmax_fwd",
"scaled_upper_triang_masked_softmax_bwd", "scaled_upper_triang_masked_softmax_bwd",
"is_softmax_kernel_available", "is_softmax_kernel_available",
"jax_scaled_softmax",
"jax_scaled_masked_softmax",
"jax_scaled_upper_triang_masked_softmax",
] ]
...@@ -330,6 +333,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -330,6 +333,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "... -> ..."
register_primitive(ScaledSoftmaxFwdPrimitive) register_primitive(ScaledSoftmaxFwdPrimitive)
...@@ -400,6 +408,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -400,6 +408,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledSoftmaxBwdPrimitive) register_primitive(ScaledSoftmaxBwdPrimitive)
...@@ -412,7 +425,7 @@ def scaled_softmax_bwd( ...@@ -412,7 +425,7 @@ def scaled_softmax_bwd(
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledSoftmaxBwdPrimitive.enabled(): if not ScaledSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp(partial(_jax_scaled_softmax, scale_factor=scale_factor), logits) _, vjp_func = jax.vjp(partial(jax_scaled_softmax, scale_factor=scale_factor), logits)
return vjp_func(dz)[0] return vjp_func(dz)[0]
return ScaledSoftmaxBwdPrimitive.outer_primitive.bind( return ScaledSoftmaxBwdPrimitive.outer_primitive.bind(
...@@ -525,6 +538,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -525,6 +538,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "...1, ...2 -> ...1"
register_primitive(ScaledMaskedSoftmaxFwdPrimitive) register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
...@@ -596,6 +614,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -596,6 +614,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledMaskedSoftmaxBwdPrimitive) register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
...@@ -682,6 +705,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -682,6 +705,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
result_infos, result_infos,
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
...@@ -761,15 +789,26 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive): ...@@ -761,15 +789,26 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
result_infos, result_infos,
) )
@staticmethod
def shardy_sharding_rule(*args):
del args
return "..., ... -> ..."
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float): def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled softmax
"""
return jax.nn.softmax(scale_factor * logits) return jax.nn.softmax(scale_factor * logits)
def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float): def jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled and masked softmax
"""
if mask is not None: if mask is not None:
logits += jax.lax.select( logits += jax.lax.select(
mask > 0, mask > 0,
...@@ -779,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac ...@@ -779,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac
return jax.nn.softmax(logits * scale_factor) return jax.nn.softmax(logits * scale_factor)
def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float): def jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
"""
JAX based implementation of scaled and upper triangle masked softmax
"""
mask = 1 - jnp.tril(jnp.ones_like(logits)) mask = 1 - jnp.tril(jnp.ones_like(logits))
logits += jax.lax.select( logits += jax.lax.select(
mask > 0, mask > 0,
...@@ -795,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray: ...@@ -795,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledSoftmaxFwdPrimitive.enabled(): if not ScaledSoftmaxFwdPrimitive.enabled():
return _jax_scaled_softmax(logits, scale_factor) return jax_scaled_softmax(logits, scale_factor)
return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor) return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
...@@ -807,7 +849,7 @@ def scaled_masked_softmax_fwd( ...@@ -807,7 +849,7 @@ def scaled_masked_softmax_fwd(
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledMaskedSoftmaxFwdPrimitive.enabled(): if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_masked_softmax(logits, mask, scale_factor) return jax_scaled_masked_softmax(logits, mask, scale_factor)
return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind( return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, mask, scale_factor=scale_factor logits, mask, scale_factor=scale_factor
) )
...@@ -826,7 +868,7 @@ def scaled_masked_softmax_bwd( ...@@ -826,7 +868,7 @@ def scaled_masked_softmax_bwd(
""" """
if not ScaledMaskedSoftmaxBwdPrimitive.enabled(): if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp( _, vjp_func = jax.vjp(
partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask partial(jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
) )
return vjp_func(dz)[0] return vjp_func(dz)[0]
return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind( return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
...@@ -840,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl ...@@ -840,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
Return FP16/BF16 tensor Return FP16/BF16 tensor
""" """
if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled(): if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor) return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind( return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
logits, scale_factor=scale_factor logits, scale_factor=scale_factor
) )
...@@ -855,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd( ...@@ -855,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd(
""" """
if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled(): if not ScaledUpperTriangMaskedSoftmaxBwdPrimitive.enabled():
_, vjp_func = jax.vjp( _, vjp_func = jax.vjp(
partial(_jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits partial(jax_scaled_upper_triang_masked_softmax, scale_factor=scale_factor), logits
) )
return vjp_func(dz)[0] return vjp_func(dz)[0]
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind( return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment