Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
@@ -335,6 +335,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu
 static_cast<const CType *>(output.scale.dptr),
 static_cast<CType *>(output.amax.dptr),
 static_cast<CType *>(output.scale_inv.dptr), row_length, num_rows);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 } else {
 NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode));
...
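Across the hunks in this commit, every raw kernel launch is followed by `NVTE_CHECK_CUDA(cudaGetLastError());` so that invalid launch configurations surface at the call site instead of on a later, unrelated CUDA call. Below is a minimal sketch of the same pattern outside this codebase, with an illustrative kernel and a plain inline check rather than the repository's macro:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel, not part of this commit.
__global__ void scale_kernel(float *data, float factor, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void launch_scale(float *data, float factor, size_t n, cudaStream_t stream) {
  const unsigned threads = 256;
  const unsigned blocks = static_cast<unsigned>((n + threads - 1) / threads);
  scale_kernel<<<blocks, threads, 0, stream>>>(data, factor, n);
  // cudaGetLastError() reports launch-time failures (bad grid size, too much
  // shared memory, missing kernel image) immediately after the launch.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(err));
  }
}
```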
@@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
 SimpleTensor &scale_inv_t, SimpleTensor &output,
 SimpleTensor &output_t, const float epsilon,
 const bool return_transpose, const bool pow_2_scale,
-cudaStream_t stream);
+const SimpleTensor &noop_tensor, cudaStream_t stream);
 // enum class for rowwise usage
 enum class FP8BlockwiseRowwiseOption {
@@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
 SimpleTensor &output_t, const float epsilon,
 FP8BlockwiseRowwiseOption rowwise_option,
 FP8BlockwiseColumnwiseOption columnwise_option,
-const bool pow_2_scale, cudaStream_t stream);
+const bool pow_2_scale, const SimpleTensor &noop_tensor,
+cudaStream_t stream);
 } // namespace transformer_engine::detail
...
@@ -269,6 +269,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt
 reinterpret_cast<InputType *>(dbias->data.dptr),
 reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
 reduce_dbias_num_rows);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ComputeType, typename Param,
@@ -787,20 +788,21 @@ void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *
 NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape.");
 NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape.");
 #ifdef __HIP_PLATFORM_AMD__
-cudaFuncSetAttribute((const void *)
+NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)
 cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
 Param, nvec_in, nvec_out, Empty, OP>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 #else
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType,
 Param, nvec_in, nvec_out, Empty, OP>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 #endif
 cast_transpose_fused_kernel_notaligned<IS_DBIAS, IS_DACT, IS_ACT, ComputeType, Param,
 nvec_in, nvec_out, Empty, OP>
 <<<num_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
 param, row_length, num_rows, num_tiles);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 if constexpr (IS_DBIAS) {
@@ -1254,15 +1256,15 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
 (THREADS_PER_WARP + 1) * sizeof(Vec<OutputType, nvec_out>);
 if (full_tile) {
 #ifdef __HIP_PLATFORM_AMD__
-cudaFuncSetAttribute((const void *)
+NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)
 dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
 OutputType, Empty, OP1, OP2>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 #else
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType,
 OutputType, Empty, OP1, OP2>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 #endif
 dgated_act_cast_transpose_kernel<nvec_in, nvec_out, ComputeType, InputType, OutputType,
@@ -1276,17 +1278,18 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
 reinterpret_cast<fp32 *>(output->amax.dptr),
 reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
 n_tiles);
+NVTE_CHECK_CUDA(cudaGetLastError());
 } else {
 #ifdef __HIP_PLATFORM_AMD__
-cudaFuncSetAttribute((const void *)
+NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)
 dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
 InputType, OutputType, Empty, OP1, OP2>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 #else
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType,
 InputType, OutputType, Empty, OP1, OP2>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 #endif
 dgated_act_cast_transpose_kernel_notaligned<nvec_in, nvec_out, ComputeType, InputType,
 OutputType, Empty, OP1, OP2>
@@ -1299,6 +1302,7 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu
 reinterpret_cast<fp32 *>(output->amax.dptr),
 reinterpret_cast<fp32 *>(output->scale_inv.dptr), row_length, num_rows,
 n_tiles);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }); // NOLINT(*)
 ); // NOLINT(*)
 }
...
@@ -258,6 +258,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
 multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
 <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
 ); // NOLINT(*)
+NVTE_CHECK_CUDA(cudaGetLastError());
 kernel_args_aligned.num_tensors = 0;
 }
 if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) {
@@ -271,6 +272,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
 multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
 <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
 ); // NOLINT(*)
+NVTE_CHECK_CUDA(cudaGetLastError());
 kernel_args_unaligned.num_tensors = 0;
 }
@@ -311,6 +313,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
 multi_cast_transpose_kernel<nvec_in, nvec_out, true, fp32, InputType, OutputType>
 <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_aligned);); // NOLINT(*)
 ); // NOLINT(*)
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 if (kernel_args_unaligned.num_tensors > 0) {
 TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
@@ -323,6 +326,7 @@ void multi_cast_transpose(const std::vector<Tensor*> input_list, std::vector<Ten
 multi_cast_transpose_kernel<nvec_in, nvec_out, false, fp32, InputType, OutputType>
 <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args_unaligned);); // NOLINT(*)
 ); // NOLINT(*)
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 }
...
@@ -18,7 +18,6 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
-#include "common/util/cuda_runtime.h"
 #include "common/util/ptx.cuh"
 #include "common/utils.cuh"
@@ -83,11 +82,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
 #ifndef __HIP_PLATFORM_AMD__
 const __grid_constant__ CUtensorMap tensor_map_output_t,
 #endif
-bool pow_2_scaling) {
+bool pow_2_scaling, const float* noop_ptr) {
 using IVec = Vec<IType, THREAD_TILE_DIM_X>;
 using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
 using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
+if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+return;
+}
 // shared mem for amax reduction in entire block, each warp produces one amax, there are
 // NUM_WARPS_IN_BLOCK amax to reduce
 __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
@@ -185,12 +188,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
 }
 }
-// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-// store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-cudaTriggerProgrammaticLaunchCompletion();
-#endif
 // Step 3: Store cast output, Step 4: do transpose within thread tile
 OVecCast tmp_output_c;
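The kernels in this file now take a `const float* noop_ptr` and exit immediately when it points at `1.0f`. Because the flag lives in device memory, earlier GPU work can turn an already-enqueued quantize/transpose into a no-op without a host round trip. A small sketch of the guard on an illustrative kernel:

```cpp
#include <cuda_runtime.h>

// Illustrative kernel, not part of this commit: every thread checks the
// device-side flag and returns before touching memory when it is set.
__global__ void copy_unless_noop(const float *__restrict__ src, float *__restrict__ dst,
                                 size_t n, const float *noop_ptr) {
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;  // skip execution if noop
  }
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}
```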
@@ -280,11 +277,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
 CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
 const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
 const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
-bool pow_2_scaling) {
+bool pow_2_scaling, const float* noop_ptr) {
 using IVec = Vec<IType, THREAD_TILE_DIM_X>;
 using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
 using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
+if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+return;
+}
 // shared mem for amax reduction in entire block, each warp produces one amax, there are
 // NUM_WARPS_IN_BLOCK amax to reduce
 __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
@@ -426,12 +427,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
 }
 }
-// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's
-// store to global memory.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-cudaTriggerProgrammaticLaunchCompletion();
-#endif
 // Step 3: Store cast output, Step 4: do transpose within thread tile
 // Edge case: in the non-full tile case, there are three subcases
 // for full thread tile, it's the same thing here
@@ -502,11 +497,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK64)
 CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
 const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
 const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
-bool pow_2_scaling) {
+bool pow_2_scaling, const float* noop_ptr) {
 using IVec = Vec<IType, THREAD_TILE_DIM_X>;
 using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
 using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
+if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+return;
+}
 // shared mem for amax reduction in entire block, each warp produces one amax, there are
 // NUM_WARPS_IN_BLOCK amax to reduce
 __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK64];
@@ -656,11 +655,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK64)
 CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
 const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
 const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
-bool pow_2_scaling) {
+bool pow_2_scaling, const float* noop_ptr) {
 using IVec = Vec<IType, THREAD_TILE_DIM_X>;
 using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
 using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
+if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+return;
+}
 // shared mem for amax reduction in entire block, each warp produces one amax, there are
 // NUM_WARPS_IN_BLOCK amax to reduce
 __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK64];
@@ -894,7 +897,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 SimpleTensor& scale_inv_t, SimpleTensor& output,
 SimpleTensor& output_t, const float epsilon,
 const bool return_transpose, const bool pow_2_scale,
-cudaStream_t stream) {
+const SimpleTensor& noop_tensor, cudaStream_t stream) {
 NVTE_API_CALL(quantize_transpose_square_blockwise);
 checkCuDriverContext(stream);
@@ -915,6 +918,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 size_t scale_t_stride_x = 0;
 size_t scale_t_stride_y = 0;
+const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
 if (return_transpose) {
 NVTE_CHECK(output_t.shape.size() == input.shape.size(),
 "output_t must have same number of dimensions as input.");
@@ -939,15 +944,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 #else
 const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
 const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
-dim3 grid(num_blocks_x, num_blocks_y, 1);
-cudaLaunchAttribute attribute[1];
-attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-attribute[0].val.programmaticStreamSerializationAllowed = 1;
-cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
-if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
-cfg.attrs = attribute;
-cfg.numAttrs = 1;
-}
 #endif
 TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
@@ -962,6 +958,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 dim3 grid(num_blocks_x, num_blocks_y, 1);
 const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
 #else
+dim3 grid(num_blocks_x, num_blocks_y, 1);
 const bool full_tile =
 row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
 #endif
@@ -972,28 +969,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 tensor_map_output_trans =
 get_tensor_map<OutputType>(output_t, num_rows, row_length);
 }
-cudaLaunchKernelEx(&cfg,
-block_scaled_cast_transpose_kernel<kReturnTranspose, float,
-InputType, OutputType>,
+block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
+<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
 reinterpret_cast<const InputType*>(input.dptr),
 reinterpret_cast<OutputType*>(output.dptr),
 reinterpret_cast<OutputType*>(output_t.dptr),
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-scale_stride_x, scale_stride_y, scale_t_stride_x,
-scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
+scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
+tensor_map_output_trans, pow_2_scale, noop_ptr);
 } else {
-cudaLaunchKernelEx(
-&cfg,
-block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
-InputType, OutputType>,
+block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
+OutputType>
+<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
 reinterpret_cast<const InputType*>(input.dptr),
 reinterpret_cast<OutputType*>(output.dptr),
 reinterpret_cast<OutputType*>(output_t.dptr),
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
-pow_2_scale);
+pow_2_scale, noop_ptr);
 #else
 while (true) {
 if (128 == block_len) {
@@ -1006,7 +1001,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-epsilon, pow_2_scale);
+epsilon, pow_2_scale, noop_ptr);
 break;
 }
 if (64 == block_len) {
@@ -1019,7 +1014,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-epsilon, pow_2_scale);
+epsilon, pow_2_scale, noop_ptr);
 break;
 }
 }
@@ -1035,7 +1030,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-epsilon, pow_2_scale);
+epsilon, pow_2_scale, noop_ptr);
 break;
 }
 if (64 == block_len) {
@@ -1048,7 +1043,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-epsilon, pow_2_scale);
+epsilon, pow_2_scale, noop_ptr);
 break;
 }
 }
...
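The lines removed from this file were the programmatic dependent launch (PDL) plumbing: a `cudaTriggerProgrammaticLaunchCompletion()` call inside the kernel plus a `cudaLaunchKernelEx` configuration carrying `cudaLaunchAttributeProgrammaticStreamSerialization` on sm90+. The launches now use the plain triple-chevron form. For reference, here is a hedged, self-contained sketch of the host/device pattern that was dropped, with an illustrative kernel that is not this file's code:

```cpp
#include <cuda_runtime.h>

__global__ void producer(float *out, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<float>(i);
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Lets a dependent kernel launched with programmatic stream serialization
  // begin before this kernel has fully retired (Hopper and newer).
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}

void launch_with_pdl(float *out, size_t n, cudaStream_t stream) {
  cudaLaunchAttribute attribute[1];
  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attribute[0].val.programmaticStreamSerializationAllowed = 1;

  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = dim3(static_cast<unsigned>((n + 255) / 256), 1, 1);
  cfg.blockDim = dim3(256, 1, 1);
  cfg.dynamicSmemBytes = 0;
  cfg.stream = stream;
  cfg.attrs = attribute;  // only meaningful on architectures that support PDL
  cfg.numAttrs = 1;

  cudaLaunchKernelEx(&cfg, producer, out, n);
}
```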
@@ -24,7 +24,6 @@
 #include "common/common.h"
 #include "common/recipe/recipe_common.cuh"
 #include "common/transpose/cast_transpose.h"
-#include "common/util/cuda_runtime.h"
 #include "common/utils.cuh"
 namespace transformer_engine {
@@ -190,7 +189,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
 const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
 FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
-const bool pow_2_scaling) {
+const bool pow_2_scaling, const float* noop_ptr) {
+// skip execution if noop
+if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+return;
+}
 bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
 bool return_columnwise_gemm_ready =
 columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
@@ -252,14 +256,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 __syncthreads();
-// If not return columnwise, we trigger the next kernel here so that it's load from global memory
-// can overlap with this kernel's return rowwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
-cudaTriggerProgrammaticLaunchCompletion();
-}
-#endif
 // Step 2: Cast and store to output_c
 if (return_rowwise) {
 constexpr int r_stride =
@@ -365,14 +361,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
 }
 }
-// If return columnwise, we trigger the next kernel here so that it's load from global memory
-// can overlap with this kernel's return columnwise.
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
-if (return_columnwise_gemm_ready || return_columnwise_compact) {
-cudaTriggerProgrammaticLaunchCompletion();
-}
-#endif
 // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
 if (return_columnwise_gemm_ready) {
 constexpr int c_stride =
@@ -1479,7 +1467,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 SimpleTensor& output_t, const float epsilon,
 FP8BlockwiseRowwiseOption rowwise_option,
 FP8BlockwiseColumnwiseOption columnwise_option,
-const bool pow2_scale, cudaStream_t stream) {
+const bool pow2_scale, const SimpleTensor& noop_tensor,
+cudaStream_t stream) {
 NVTE_API_CALL(quantize_transpose_vector_blockwise);
 const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
@@ -1532,14 +1521,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 "Input and output_t must have the same shape for columnwise non-transpose case.");
 }
 }
-NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
+if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
+// output may not be defined if rowwise quantization is not needed.
+NVTE_CHECK(output.dtype == output_t.dtype,
+"output and output_t need to have the same dtype.");
+}
 NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
 bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
 size_t scale_t_k = scale_inv_t.shape[1];
 scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
 scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
 }
+auto output_dtype =
+rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype;
 #ifdef __HIP_PLATFORM_AMD__
 const size_t block_len = blockwise_fp8_block_len();
@@ -1548,10 +1542,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 #else
 const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
 const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
-dim3 grid(num_blocks_x, num_blocks_y, 1);
-cudaLaunchAttribute attribute[1];
-attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
-attribute[0].val.programmaticStreamSerializationAllowed = 1;
+const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
 #endif
 TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
@@ -1563,6 +1555,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 dim3 grid(num_blocks_x, num_blocks_y, 1);
 const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
 #else
+dim3 grid(num_blocks_x, num_blocks_y, 1);
 const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
 #endif
 TRANSFORMER_ENGINE_SWITCH_CONDITION(
@@ -1575,12 +1568,12 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 size_t smem_bytes = kSMemSize_Rowwise * sizeof(InputType);
 const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len));
 const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len / 2));
-dim3 grid(num_blocks_x, num_blocks_y, 1);
 if (smem_bytes >= 48 * 1024) {
 cudaError_t err = cudaFuncSetAttribute(
 (const void*)&block_scaled_1d_cast_transpose_kernel_rowwise<
 kAligned, float, InputType, OutputType>,
 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
 }
 block_scaled_1d_cast_transpose_kernel_rowwise<kAligned, float, InputType,
 OutputType>
@@ -1588,17 +1581,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<const InputType*>(input.dptr),
 reinterpret_cast<OutputType*>(output.dptr),
 reinterpret_cast<float*>(scale_inv.dptr), row_length, num_rows,
-scale_stride_x, scale_stride_y, epsilon, rowwise_option, pow2_scale);
+scale_stride_x, scale_stride_y, epsilon, rowwise_option, pow2_scale, noop_ptr);
 }
 if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
 size_t smem_bytes = kSMemSize_Colwise * sizeof(InputType);
 const size_t num_blocks_x = DIVUP(row_length, (size_t)(block_len / 2));
 const size_t num_blocks_y = DIVUP(num_rows, (size_t)(block_len));
-dim3 grid(num_blocks_x, num_blocks_y, 1);
+if (smem_bytes >= 48 * 1024) {
 cudaError_t err = cudaFuncSetAttribute(
 (const void*)&block_scaled_1d_cast_transpose_kernel_colwise<
 kAligned, float, InputType, OutputType>,
 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
+}
 block_scaled_1d_cast_transpose_kernel_colwise<kAligned, float, InputType,
 OutputType>
 <<<grid, kThreadsPerBlock_Colwise, smem_bytes, stream>>>(
@@ -1606,7 +1601,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<OutputType*>(output_t.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_t_stride_x, scale_t_stride_y, epsilon, columnwise_option,
-pow2_scale);
+pow2_scale, noop_ptr);
 }
 break;
 }
@@ -1617,6 +1612,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 (const void*)&block_scaled_block_len64_1d_cast_transpose_kernel<
 kAligned, float, InputType, OutputType>,
 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
 }
 block_scaled_block_len64_1d_cast_transpose_kernel<kAligned, float, InputType,
 OutputType>
@@ -1627,36 +1623,27 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
 reinterpret_cast<float*>(scale_inv.dptr),
 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
-epsilon, rowwise_option, columnwise_option, pow2_scale);
+epsilon, rowwise_option, columnwise_option, pow2_scale, noop_ptr);
 break;
 }
 }
 #else
 size_t smem_bytes = kSMemSize * sizeof(InputType);
-cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
-if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
-90) {
-cfg.attrs = attribute;
-cfg.numAttrs = 1;
-}
 // shared memory must be requested up
 if (smem_bytes >= 48 * 1024) {
 cudaError_t err = cudaFuncSetAttribute(
 &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
 cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
 NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
 }
-cudaLaunchKernelEx(&cfg,
-block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
-OutputType>,
+block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
+<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
 reinterpret_cast<const InputType*>(input.dptr),
 reinterpret_cast<OutputType*>(output.dptr),
 reinterpret_cast<OutputType*>(output_t.dptr),
 reinterpret_cast<float*>(scale_inv.dptr),
-reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
-scale_stride_x, scale_stride_y, scale_t_stride_x,
-scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
-pow2_scale);
+reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
+scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
+columnwise_option, pow2_scale, noop_ptr);
 #endif
 ) // kAligned
 ) // OutputType
...
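Requests for more than 48 KB of dynamic shared memory require an explicit opt-in through `cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...)`, and the hunks above now verify that call with `NVTE_CHECK` instead of discarding its result. A minimal sketch of the opt-in on an illustrative kernel:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative kernel using dynamic shared memory.
__global__ void smem_kernel(float *data, size_t n) {
  extern __shared__ float tile[];
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  tile[threadIdx.x] = (i < n) ? data[i] : 0.0f;
  __syncthreads();
  if (i < n) data[i] = tile[threadIdx.x];
}

bool launch_with_large_smem(float *data, size_t n, size_t smem_bytes, cudaStream_t stream) {
  if (smem_bytes >= 48 * 1024) {
    // Launches above 48 KB fail unless the kernel opts in first; the return
    // code also catches sizes the device cannot provide at all.
    cudaError_t err = cudaFuncSetAttribute(smem_kernel,
                                           cudaFuncAttributeMaxDynamicSharedMemorySize,
                                           static_cast<int>(smem_bytes));
    if (err != cudaSuccess) {
      fprintf(stderr, "failed to set dynamic shared memory size: %s\n",
              cudaGetErrorString(err));
      return false;
    }
  }
  const unsigned blocks = static_cast<unsigned>((n + 255) / 256);
  smem_kernel<<<blocks, 256, smem_bytes, stream>>>(data, n);
  return cudaGetLastError() == cudaSuccess;
}
```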
@@ -279,6 +279,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
 static_cast<const fp32 *>(noop.data.dptr),
 static_cast<Type *>(output.data.dptr),
 row_length, num_rows);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }); // NOLINT(*)
 }
...
@@ -420,6 +420,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt
 reinterpret_cast<BiasType *>(dbias->data.dptr),
 reinterpret_cast<const fp32 *>(workspace.data.dptr), reduce_dbias_row_length,
 reduce_dbias_num_rows);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias,
@@ -490,17 +491,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor
 }
 #else
 if (full_tile) {
-cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(transpose_dbias_kernel<nvec_in, nvec_out, Param>,
+cudaFuncAttributePreferredSharedMemoryCarveout,
+100));
 transpose_dbias_kernel<nvec_in, nvec_out, Param>
 <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
 param, row_length, num_rows, n_tiles);
+NVTE_CHECK_CUDA(cudaGetLastError());
 } else {
+NVTE_CHECK_CUDA(
 cudaFuncSetAttribute(transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>,
-cudaFuncAttributePreferredSharedMemoryCarveout, 100);
+cudaFuncAttributePreferredSharedMemoryCarveout, 100));
 transpose_dbias_kernel_notaligned<nvec_in, nvec_out, Param>
 <<<n_blocks, cast_transpose_num_threads, shared_size_transpose, stream>>>(
 param, row_length, num_rows, n_tiles);
+NVTE_CHECK_CUDA(cudaGetLastError());
 }
 #endif
 reduce_dbias<BiasType>(*workspace, dbias, row_length, num_rows, nvec_out,
...
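`cudaFuncAttributePreferredSharedMemoryCarveout` is a hint, expressed as a percentage of the combined L1/shared-memory storage to reserve for shared memory; the hunks above now pass the call through `NVTE_CHECK_CUDA` so an error return is no longer dropped. A short sketch of the same call on an illustrative shared-memory-heavy kernel:

```cpp
#include <cuda_runtime.h>

// Illustrative shared-memory-heavy kernel.
__global__ void tile_kernel(const float *in, float *out, int n) {
  __shared__ float tile[256];
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  tile[threadIdx.x] = (i < n) ? in[i] : 0.0f;
  __syncthreads();
  if (i < n) out[i] = tile[blockDim.x - 1 - threadIdx.x];
}

cudaError_t launch_with_carveout(const float *in, float *out, int n, cudaStream_t stream) {
  // Prefer the maximum shared-memory carveout (100%). The runtime treats the
  // value as a hint, but the return code still flags invalid arguments.
  cudaError_t err =
      cudaFuncSetAttribute(tile_kernel, cudaFuncAttributePreferredSharedMemoryCarveout, 100);
  if (err != cudaSuccess) return err;
  tile_kernel<<<(n + 255) / 256, 256, 0, stream>>>(in, out, n);
  return cudaGetLastError();
}
```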
@@ -959,15 +959,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
 const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
 (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
-cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
+cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
 cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
 <<<grid_dim, block_dim, shmem_size, stream>>>(
 tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
 tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
-cols);); // NOLINT(*)
+cols);
+NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
 ); // NOLINT(*)
 #endif
 }
@@ -1095,11 +1096,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
 switch (scaling_type) {
 case ScalingType::ROWWISE:
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
 OType, true, false,
 THREADS_PER_CHUNK_NON_COLWISE>,
-cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
+cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
 mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
 true, false, THREADS_PER_CHUNK_NON_COLWISE>
@@ -1109,13 +1110,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
 tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
 scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
 scale_stride_colwise);
+NVTE_CHECK_CUDA(cudaGetLastError());
 break;
 case ScalingType::COLWISE:
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
 OType, false, true,
 THREADS_PER_CHUNK_COLWISE>,
-cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
+cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
 mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
 false, true, THREADS_PER_CHUNK_COLWISE>
@@ -1125,13 +1127,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
 tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
 scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
 scale_stride_colwise);
+NVTE_CHECK_CUDA(cudaGetLastError());
 break;
 case ScalingType::BIDIMENSIONAL:
-cudaFuncSetAttribute(
+NVTE_CHECK_CUDA(cudaFuncSetAttribute(
 mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
 OType, true, true,
 THREADS_PER_CHUNK_NON_COLWISE>,
-cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);
+cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
 mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
 true, true, THREADS_PER_CHUNK_NON_COLWISE>
@@ -1141,6 +1144,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
 tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
 scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
 scale_stride_colwise);
+NVTE_CHECK_CUDA(cudaGetLastError());
 break;
 }); // NOLINT(*)
 ); // NOLINT(*)
...
...@@ -899,6 +899,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, ...@@ -899,6 +899,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
reduce_dbias_kernel<reduce_dbias_nvec, IType> reduce_dbias_kernel<reduce_dbias_nvec, IType>
<<<reduce_dbias_num_blocks, DBIAS_THREADS_PER_BLOCK, 0, stream>>>( <<<reduce_dbias_num_blocks, DBIAS_THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols); reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, rows, cols);
NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)> template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
...@@ -930,6 +931,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream ...@@ -930,6 +931,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream
cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>( cast_fp8_1D_kernel<IS_ACT, ParamOP, OP, IType, OType><<<grid, block, 0, stream>>>(
input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)> template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
...@@ -996,6 +998,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T ...@@ -996,6 +998,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output, <<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols); cols);
NVTE_CHECK_CUDA(cudaGetLastError());
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
...@@ -1134,55 +1137,51 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1134,55 +1137,51 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
// This kernel will only be called on sm100+, so no need to check sm_arch
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1; cfg.attrs = attribute;
cfg.numAttrs = 1;
switch (scaling_type) { switch (scaling_type) {
case ScalingType::ROWWISE: case ScalingType::ROWWISE:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::COLWISE: case ScalingType::COLWISE:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
case ScalingType::BIDIMENSIONAL: case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute( NVTE_CHECK_CUDA(cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
NVTE_CHECK_CUDA(cudaGetLastError());
break; break;
} }
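Throughout this hunk the launches move from cudaLaunchKernelEx with a programmatic-stream-serialization attribute to plain triple-chevron launches followed immediately by NVTE_CHECK_CUDA(cudaGetLastError()). A minimal standalone sketch of that launch-then-check pattern; the CHECK_CUDA macro and scale_kernel below are illustrative stand-ins, not Transformer Engine code:
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
// Illustrative stand-in for NVTE_CHECK_CUDA.
#define CHECK_CUDA(expr)                                                    \
  do {                                                                      \
    const cudaError_t err_ = (expr);                                        \
    if (err_ != cudaSuccess) {                                              \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_));   \
      std::abort();                                                         \
    }                                                                       \
  } while (false)
__global__ void scale_kernel(float *data, float factor, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}
int main() {
  constexpr int n = 1024;
  float *d = nullptr;
  CHECK_CUDA(cudaMalloc(&d, n * sizeof(float)));
  // Calling cudaGetLastError() right after the launch surfaces launch-configuration
  // errors (bad grid/block dims, excess dynamic shared memory) without forcing a
  // device synchronization.
  scale_kernel<<<(n + 255) / 256, 256, 0, 0>>>(d, 2.0f, n);
  CHECK_CUDA(cudaGetLastError());
  CHECK_CUDA(cudaFree(d));
  return 0;
}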
...@@ -1464,7 +1463,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1464,7 +1463,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
quantize_transpose_square_blockwise( quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor.data, stream);
break; break;
} }
case NVTE_BLOCK_SCALING_1D: { case NVTE_BLOCK_SCALING_1D: {
...@@ -1492,10 +1492,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1492,10 +1492,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
} }
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, quantize_transpose_vector_blockwise(
output_tensor->columnwise_scale_inv, output_tensor->data, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->columnwise_data, epsilon, rowwise_option, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream); columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
break; break;
} }
default: default:
......
...@@ -336,6 +336,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -336,6 +336,7 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#endif #endif
} }
} // namespace dequantization } // namespace dequantization
......
...@@ -23,8 +23,13 @@ ...@@ -23,8 +23,13 @@
#endif // __HIP_PLATFORM_AMD__ #endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h> #include <nvrtc.h>
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
#include <iostream> #include <iostream>
#include <stdexcept> #include <stdexcept>
#include <string>
#include "../util/string.h" #include "../util/string.h"
...@@ -130,4 +135,16 @@ ...@@ -130,4 +135,16 @@
} \ } \
} while (false) } while (false)
#ifdef NVTE_WITH_CUBLASMP
#define NVTE_CHECK_CUBLASMP(expr) \
do { \
const cublasMpStatus_t status = (expr); \
if (status != CUBLASMP_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \
} \
} while (false)
#endif // NVTE_WITH_CUBLASMP
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
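The new NVTE_CHECK_CUBLASMP macro follows the same shape as the other status-check macros in this header: evaluate the expression once, compare the status to CUBLASMP_STATUS_SUCCESS, and route failures through NVTE_ERROR with the numeric status. A hedged usage sketch; the cuBLASMp calls and their signatures below are assumptions for illustration, not verified against the cuBLASMp headers:
#ifdef NVTE_WITH_CUBLASMP
// Illustrative only: wrap each cuBLASMp call so a non-success status is reported
// through NVTE_ERROR instead of being silently ignored. cublasMpCreate and
// cublasMpDestroy, and their argument lists, are assumed here.
void cublasmp_smoke_test(cudaStream_t stream) {
  cublasMpHandle_t handle = nullptr;
  NVTE_CHECK_CUBLASMP(cublasMpCreate(&handle, stream));  // hypothetical signature
  NVTE_CHECK_CUBLASMP(cublasMpDestroy(handle));          // hypothetical signature
}
#endif  // NVTE_WITH_CUBLASMP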
...@@ -248,6 +248,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o ...@@ -248,6 +248,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type> multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
} }
...@@ -277,6 +278,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o ...@@ -277,6 +278,7 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_padding_kernel<nvec, Type> multi_padding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -322,6 +324,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> ...@@ -322,6 +324,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*>
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_unpadding_kernel<nvec, Type> multi_unpadding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
kernel_args.num_tensors = 0; kernel_args.num_tensors = 0;
} }
...@@ -349,6 +352,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*> ...@@ -349,6 +352,7 @@ void multi_unpadding(const std::vector<Tensor*> input_list, std::vector<Tensor*>
const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
multi_unpadding_kernel<nvec, Type> multi_unpadding_kernel<nvec, Type>
<<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*) <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args);); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
......
...@@ -371,6 +371,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out ...@@ -371,6 +371,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -405,6 +406,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp ...@@ -405,6 +406,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -498,6 +500,7 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c ...@@ -498,6 +500,7 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
...@@ -609,6 +612,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu ...@@ -609,6 +612,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu
break; break;
} }
} }
NVTE_CHECK_CUDA(cudaGetLastError());
} }
} }
......
...@@ -199,6 +199,15 @@ STATS = { ...@@ -199,6 +199,15 @@ STATS = {
), ),
} }
FP8_NEGATIVE_ZERO = 128  # represents -0.0 in fp8
def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor:
"""Count the number of non-zero elements in the fp8 data."""
fp8_data = fp8_data.view(dtype=torch.uint8)
zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8)
return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum()
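In the raw byte view, -0.0 is encoded as 0x80 (decimal 128), so a plain count_nonzero over the bytes would count negative zeros as non-zero and overstate the statistic. A small self-contained sketch of the difference; the byte values are made up for illustration and the helper mirrors the one added above:
import torch

FP8_NEGATIVE_ZERO = 128  # sign bit only: the byte pattern of -0.0 in fp8

def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor:
    # Treat both 0x00 (+0.0) and 0x80 (-0.0) as zero.
    fp8_data = fp8_data.view(dtype=torch.uint8)
    zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8)
    return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum()

# Illustrative encodings: +0.0, -0.0, and two arbitrary non-zero bytes.
data = torch.tensor([0x00, 0x80, 0x3C, 0x44], dtype=torch.uint8)
print(int(torch.count_nonzero(data)))  # 3: -0.0 is wrongly counted as non-zero
print(int(count_nonzero_fp8(data)))    # 2: both zero encodings excluded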
def add_underflows_stats(recipe_name: str, columnwise: bool = False): def add_underflows_stats(recipe_name: str, columnwise: bool = False):
"""Register *both* underflow stats (num and %) for the given recipe.""" """Register *both* underflow stats (num and %) for the given recipe."""
...@@ -212,22 +221,23 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): ...@@ -212,22 +221,23 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False):
stats_to_num[stat_pct] = len(stats_to_num) stats_to_num[stat_pct] = len(stats_to_num)
STATS[stat_num] = ( STATS[stat_num] = (
lambda x, aux_dict: ( lambda x, aux_dict: x.count_nonzero()
- count_nonzero_fp8(
aux_dict[recipe_name].get_data_tensors( aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise rowwise_data=not columnwise, columnwise_data=columnwise
) )
== 0 ),
).sum()
- (x == 0).sum(),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
) )
STATS[stat_pct] = ( STATS[stat_pct] = (
lambda x, aux_dict: ( lambda x, aux_dict: (
x.count_nonzero()
- count_nonzero_fp8(
aux_dict[recipe_name].get_data_tensors( aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise rowwise_data=not columnwise, columnwise_data=columnwise
) )
== 0 )
).sum() )
/ aux_dict[recipe_name].numel() / aux_dict[recipe_name].numel()
* 100, * 100,
lambda buffers, _sn_num=stat_num: 100 lambda buffers, _sn_num=stat_num: 100
......
...@@ -38,19 +38,10 @@ from .quantize import fp8_autocast, update_collections, get_delayed_scaling ...@@ -38,19 +38,10 @@ from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import NVTE_FP8_COLLECTION_NAME from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource from .sharding import MeshResource
from .sharding import MajorShardingType, ShardingResource, ShardingType
from ..common.utils import deprecate_wrapper from ..common.utils import deprecate_wrapper
from ..common.utils import DeprecatedEnum from ..common.utils import DeprecatedEnum
MajorShardingType = DeprecatedEnum(
MajorShardingType, "MajorShardingType is deprecating in the near feature."
)
ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near feature.")
ShardingResource = deprecate_wrapper(
ShardingResource,
"ShardingResource is renamed to MeshResource, and will be removed in the near feature.",
)
__all__ = [ __all__ = [
"NVTE_FP8_COLLECTION_NAME", "NVTE_FP8_COLLECTION_NAME",
...@@ -58,9 +49,6 @@ __all__ = [ ...@@ -58,9 +49,6 @@ __all__ = [
"update_collections", "update_collections",
"get_delayed_scaling", "get_delayed_scaling",
"MeshResource", "MeshResource",
"MajorShardingType",
"ShardingResource",
"ShardingType",
"flax", "flax",
"quantize", "quantize",
] ]
...@@ -14,7 +14,7 @@ import jax.numpy as jnp ...@@ -14,7 +14,7 @@ import jax.numpy as jnp
from . import cpp_extensions as tex from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor from .quantize.tensor import NoScaleTensor
from .quantize.quantizer import Quantizer from .quantize.quantizer import Quantizer
...@@ -22,7 +22,7 @@ def activation( ...@@ -22,7 +22,7 @@ def activation(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> jnp.ndarray:
"""Apply activation functions to input tensor with optional quantization. """Apply activation functions to input tensor with optional quantization.
This function applies a sequence of activation functions to the input tensor. This function applies a sequence of activation functions to the input tensor.
...@@ -72,7 +72,7 @@ def _activation_fwd_rule(x, activation_type, quantizer): ...@@ -72,7 +72,7 @@ def _activation_fwd_rule(x, activation_type, quantizer):
Tuple of (output, context) for backward pass Tuple of (output, context) for backward pass
""" """
fwd_output = tex.act_lu(x, activation_type, quantizer) fwd_output = tex.act_lu(x, activation_type, quantizer)
if isinstance(fwd_output, ScaledTensor): # This is a no-op for higher-precision tensors
fwd_output = fwd_output.dequantize() fwd_output = fwd_output.dequantize()
return fwd_output, (x, quantizer) return fwd_output, (x, quantizer)
...@@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g): ...@@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g):
(x, _) = ctx (x, _) = ctx
assert x.dtype == g.dtype assert x.dtype == g.dtype
dx = tex.dact_lu(g, x, activation_type) dx = tex.dact_lu(g, x, activation_type)
# No quantization is used in this VJP backward, so the output should
# always be a NoScaleTensor
assert isinstance(dx, NoScaleTensor)
dx = dx.data
return (dx, None) return (dx, None)
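After this change, unquantized outputs from tex.dact_lu come back wrapped in a NoScaleTensor, so the VJP rule asserts the wrapper type and unwraps .data before returning the gradient to JAX. A toy sketch of that wrap/unwrap convention, using a simplified stand-in for NoScaleTensor rather than the real class from transformer_engine.jax.quantize:
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass
class NoScaleTensor:  # simplified stand-in for the real wrapper class
    data: jnp.ndarray
    amax: Optional[jnp.ndarray] = None


def toy_dact(g: jnp.ndarray) -> NoScaleTensor:
    # No quantizer in the backward pass, so the result stays wrapped.
    return NoScaleTensor(data=g * 0.5, amax=None)


g = jnp.ones((4,))
dx = toy_dact(g)
assert isinstance(dx, NoScaleTensor)
dx = dx.data  # unwrap before handing the gradient back to JAX
print(dx)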
......
...@@ -29,7 +29,7 @@ from .misc import ( ...@@ -29,7 +29,7 @@ from .misc import (
) )
from .quantization import _jax_dbias, _quantize_dbias_impl from .quantization import _jax_dbias, _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
QuantizeLayout, QuantizeLayout,
...@@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): ...@@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive):
"""Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]:
""" """
JAX native activation implementation JAX native activation implementation
""" """
...@@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S ...@@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
x = jnp.squeeze(x, axis=-2) x = jnp.squeeze(x, axis=-2)
if quantizer: if quantizer:
return quantizer.quantize(x, flatten_axis=-1) return quantizer.quantize(x, flatten_axis=-1)
return x return NoScaleTensor(data=x, amax=None)
def _jax_quantize_dact_dbias( def _jax_quantize_dact_dbias(
dz: jnp.ndarray, dz: Union[jnp.ndarray, NoScaleTensor],
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
is_dbias: bool = True, is_dbias: bool = True,
...@@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias( ...@@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias(
_, vjp_func = jax.vjp( _, vjp_func = jax.vjp(
partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
) )
(dx,) = vjp_func(dz.astype(jnp.float32)) # The VJP uses the non-quantized dact backward, so the incoming gradient must always be wrapped in a NoScaleTensor, regardless of whether the forward pass used quantization or whether this dact quantizes afterwards.
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None)
(dx,) = vjp_func(dz)
dbias = None dbias = None
if is_dbias: if is_dbias:
...@@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias( ...@@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias(
dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2)
else: else:
dx = dx.astype(x.dtype) dx = dx.astype(x.dtype)
dx = NoScaleTensor(data=dx, amax=None)
return dx, dbias return dx, dbias
...@@ -981,7 +984,6 @@ def act_lu( ...@@ -981,7 +984,6 @@ def act_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization. """Activation with optional quantization.
...@@ -990,7 +992,6 @@ def act_lu( ...@@ -990,7 +992,6 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply. activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
If quantizer is None: If quantizer is None:
...@@ -1035,16 +1036,16 @@ def act_lu( ...@@ -1035,16 +1036,16 @@ def act_lu(
is_outer=True, is_outer=True,
) )
out = out.reshape(output_shape) out = out.reshape(output_shape)
if noop_scaled_tensor: out = NoScaleTensor(
return ScaledTensorFactory.create_2x( data=out,
out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype amax=None,
) )
return out return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after. # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = act_lu( out = act_lu(
x=x.astype(jnp.float32), x=x,
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
) )
...@@ -1092,7 +1093,6 @@ def quantize_dact_dbias( ...@@ -1092,7 +1093,6 @@ def quantize_dact_dbias(
activation_type: Sequence[Union[str, Callable]] = ("gelu",), activation_type: Sequence[Union[str, Callable]] = ("gelu",),
is_dbias: bool = True, is_dbias: bool = True,
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]: ) -> Tuple[ScaledTensor, jnp.ndarray]:
"""Compute gradients of activation and bias with optional quantization. """Compute gradients of activation and bias with optional quantization.
...@@ -1103,7 +1103,6 @@ def quantize_dact_dbias( ...@@ -1103,7 +1103,6 @@ def quantize_dact_dbias(
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
...@@ -1146,19 +1145,10 @@ def quantize_dact_dbias( ...@@ -1146,19 +1145,10 @@ def quantize_dact_dbias(
if is_dbias: if is_dbias:
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
if noop_scaled_tensor: output = NoScaleTensor(
return ( data=output,
ScaledTensorFactory.create_2x( amax=None,
output,
None,
output,
None,
ScalingMode.NO_SCALING,
dq_dtype=output.dtype,
),
dbias,
) )
return output, dbias return output, dbias
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
...@@ -1167,7 +1157,7 @@ def quantize_dact_dbias( ...@@ -1167,7 +1157,7 @@ def quantize_dact_dbias(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
) )
return _quantize_dbias_impl( return _quantize_dbias_impl(
out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
) )
is_gated = act_len == 2 is_gated = act_len == 2
...@@ -1188,13 +1178,13 @@ def quantize_dact_dbias( ...@@ -1188,13 +1178,13 @@ def quantize_dact_dbias(
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision then quantize after. # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
out = dact_lu( out = dact_lu(
dz=dz.astype(jnp.float32), dz=dz,
x=x.astype(jnp.float32), x=x,
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
) )
out, dbias = _quantize_dbias_impl( out, dbias = _quantize_dbias_impl(
out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
) )
return out, dbias return out, dbias
...@@ -1258,7 +1248,6 @@ def dact_lu( ...@@ -1258,7 +1248,6 @@ def dact_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
noop_scale_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
""" """
Backward pass for activation with optional quantization. Backward pass for activation with optional quantization.
...@@ -1268,7 +1257,6 @@ def dact_lu( ...@@ -1268,7 +1257,6 @@ def dact_lu(
x: Input tensor that was used in forward pass. x: Input tensor that was used in forward pass.
activation_type: Type of activation function that was applied. activation_type: Type of activation function that was applied.
quantizer: Optional quantizer for FP8 quantization of the output gradient. quantizer: Optional quantizer for FP8 quantization of the output gradient.
noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.
Returns: Returns:
The gradient of the activation with respect to the input. The gradient of the activation with respect to the input.
...@@ -1279,6 +1267,5 @@ def dact_lu( ...@@ -1279,6 +1267,5 @@ def dact_lu(
activation_type=activation_type, activation_type=activation_type,
is_dbias=False, is_dbias=False,
quantizer=quantizer, quantizer=quantizer,
noop_scaled_tensor=noop_scale_tensor,
) )
return output return output
...@@ -34,6 +34,7 @@ from .misc import ( ...@@ -34,6 +34,7 @@ from .misc import (
te_dtype_to_jax_dtype, te_dtype_to_jax_dtype,
get_padded_spec, get_padded_spec,
get_cudnn_version, get_cudnn_version,
get_all_device_compute_capability,
) )
from ..sharding import ( from ..sharding import (
global_mesh_resource, global_mesh_resource,
...@@ -2745,6 +2746,11 @@ def fused_attn_bwd( ...@@ -2745,6 +2746,11 @@ def fused_attn_bwd(
assert bias is None assert bias is None
bias = jnp.zeros(0, dtype=qkv[0].dtype) bias = jnp.zeros(0, dtype=qkv[0].dtype)
if 100 in get_all_device_compute_capability():
assert not (
attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported"
fused_config = _FusedAttnConfig( fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_mask_type=attn_mask_type, attn_mask_type=attn_mask_type,
......
...@@ -134,6 +134,13 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -134,6 +134,13 @@ class BasePrimitive(metaclass=ABCMeta):
""" """
return NotImplemented return NotImplemented
@classmethod
def outer_impl(cls, *args, **kwargs):
"""
to describe the implementation for the outer primitive; defaults to the inner implementation
"""
return cls.impl(*args, **kwargs)
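register_primitive now binds the outer primitive's eager implementation to cls.outer_impl while the custom-partitioning lowering keeps using cls.impl, so a subclass can specialize eager (outer) evaluation without touching the sharded path. A minimal sketch of that default-and-override pattern in plain Python; the subclass and its behavior are illustrative, not part of Transformer Engine:
class BasePrimitive:
    """Minimal stand-in for the real base class."""

    @classmethod
    def impl(cls, *args, **kwargs):
        raise NotImplementedError

    @classmethod
    def outer_impl(cls, *args, **kwargs):
        # Default: the outer primitive simply reuses the inner implementation.
        return cls.impl(*args, **kwargs)


class ToyPrimitive(BasePrimitive):
    @classmethod
    def impl(cls, x):
        return 2 * x  # inner implementation, also used for lowering

    @classmethod
    def outer_impl(cls, x):
        # Hypothetical override: extra bookkeeping only for eager outer calls.
        return cls.impl(x) + 1


print(ToyPrimitive.impl(3))        # 6
print(ToyPrimitive.outer_impl(3))  # 7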
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def batcher(): def batcher():
...@@ -196,7 +203,7 @@ def register_primitive(cls): ...@@ -196,7 +203,7 @@ def register_primitive(cls):
outer_p = core.Primitive(name_of_wrapper_p()) outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p) dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl) outer_p.def_impl(cls.outer_impl)
outer_p.def_abstract_eval(cls.outer_abstract) outer_p.def_abstract_eval(cls.outer_abstract)
batching.primitive_batchers[outer_p] = cls.batcher batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
...@@ -219,7 +226,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F ...@@ -219,7 +226,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
""" """
Helper function to manage primitive states by name without modifying environment variables. Helper function to manage primitive states by name without modifying environment variables.
Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
This helper is used in the QuantizeConfig.initialize() methods. This helper is used in the get_quantize_config().initialize() methods.
Args: Args:
enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
......