OpenDAS / TransformerEngine

Commit 27ddce40, authored Oct 11, 2025 by wenjh

    Merge branch 'nv_main'

Parents: d262ef4c, 5b3092a0

Changes: 208. Showing 20 changed files with 575 additions and 169 deletions (+575, -169).
Changed files shown:

  transformer_engine/common/include/transformer_engine/fused_rope.h                +63    -0
  transformer_engine/common/include/transformer_engine/gemm.h                       +6    -14
  transformer_engine/common/include/transformer_engine/normalization.h             +30    -9
  transformer_engine/common/include/transformer_engine/recipe.h                    +15    -0
  transformer_engine/common/multi_tensor/l2norm.cu                                  +2    -0
  transformer_engine/common/normalization/common.cpp                               +12    -7
  transformer_engine/common/normalization/common.h                                 +12    -7
  transformer_engine/common/normalization/layernorm/ln_api.cpp                      +4    -2
  transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu     +19   -15
  transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu          +17   -14
  transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp                  +86    -2
  transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh          +47    -8
  transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu  +137  -32
  transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu       +17   -16
  transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu                       +9    -6
  transformer_engine/common/permutation/permutation.cu                              +6    -0
  transformer_engine/common/recipe/current_scaling.cu                              +57   -11
  transformer_engine/common/recipe/fp8_block_scaling.cu                             +2    -0
  transformer_engine/common/swizzle/swizzle.cu                                     +32   -24
  transformer_engine/common/transformer_engine.cpp                                  +2    -2
transformer_engine/common/include/transformer_engine/fused_rope.h

@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
                               const int stride_b, const int stride_h, const int stride_d,
                               cudaStream_t stream);
 
+/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
+ *
+ *  \param[in]     qkv_input             Combined QKV input tensor for fused rope.
+ *  \param[in]     q_freqs               The freqs tensor for Q.
+ *  \param[in]     k_freqs               The freqs tensor for K.
+ *  \param[in]     start_positions       The beginning offsets for applying RoPE embeddings.
+ *  \param[out]    q_out                 Output tensor for Q.
+ *  \param[out]    k_out                 Output tensor for K.
+ *  \param[out]    v_out                 Output tensor for V.
+ *  \param[in]     qkv_format            QKV format.
+ *  \param[in]     interleaved           Whether to use interleaved rotary position embedding.
+ *  \param[in]     cp_size               Context parallel world size.
+ *  \param[in]     cp_rank               Context parallel rank.
+ *  \param[in]     s                     Length of the s dimension of input.
+ *  \param[in]     b                     Length of the b dimension of input.
+ *  \param[in]     h                     Length of the h dimension of input.
+ *  \param[in]     d                     Length of the d dimension of input.
+ *  \param[in]     d2                    Length of the d dimension of freqs.
+ *  \param[in]     qkv_split_arg_list_0  The hidden size for Q.
+ *  \param[in]     qkv_split_arg_list_1  The hidden size for K.
+ *  \param[in]     qkv_split_arg_list_2  The hidden size for V.
+ *  \param[in]     stream                CUDA stream used for the operation.
+ */
+void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
+                                 const NVTETensor k_freqs, const NVTETensor start_positions,
+                                 NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
+                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                 const int cp_size, const int cp_rank, const int s, const int b,
+                                 const int h, const int d, const int d2,
+                                 const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                                 const int qkv_split_arg_list_2, cudaStream_t stream);
+
+/*! \brief Compute the backward of the fused qkv rope.
+ *
+ *  \param[in]     q_grad_out            Incoming gradient tensor for Q.
+ *  \param[in]     k_grad_out            Incoming gradient tensor for K.
+ *  \param[in]     v_grad_out            Incoming gradient tensor for V.
+ *  \param[in]     q_freqs               The freqs tensor for Q.
+ *  \param[in]     k_freqs               The freqs tensor for K.
+ *  \param[out]    qkv_grad_input        Input gradient tensor to calculate.
+ *  \param[in]     qkv_format            QKV format.
+ *  \param[in]     interleaved           Whether to use interleaved rotary position embedding.
+ *  \param[in]     cp_size               Context parallel world size.
+ *  \param[in]     cp_rank               Context parallel rank.
+ *  \param[in]     s                     Length of the s dimension of input.
+ *  \param[in]     b                     Length of the b dimension of input.
+ *  \param[in]     h                     Length of the h dimension of input.
+ *  \param[in]     d                     Length of the d dimension of input.
+ *  \param[in]     d2                    Length of the d dimension of freqs.
+ *  \param[in]     qkv_split_arg_list_0  The hidden size for Q.
+ *  \param[in]     qkv_split_arg_list_1  The hidden size for K.
+ *  \param[in]     qkv_split_arg_list_2  The hidden size for V.
+ *  \param[in]     stream                CUDA stream used for the operation.
+ */
+void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
+                                  const NVTETensor v_grad_out, const NVTETensor q_freqs,
+                                  const NVTETensor k_freqs, NVTETensor qkv_grad_input,
+                                  const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                  const int cp_size, const int cp_rank, const int s, const int b,
+                                  const int h, const int d, const int d2,
+                                  const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
+                                  const int qkv_split_arg_list_2, cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
 #endif
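The two declarations above are the whole calling surface of the new fused QKV RoPE path. Below is a minimal sketch of a forward call; the wrapper, its tensor handles, and the layout choice are illustrative assumptions, not code from this commit. Only the nvte_fused_qkv_rope_forward signature itself comes from the header above.

#include <transformer_engine/fused_rope.h>

// Hypothetical wrapper: the NVTETensor handles are assumed to be created and
// shaped by the caller; h_q/h_k/h_v are the per-projection hidden sizes that
// the packed QKV tensor splits into (qkv_split_arg_list_0/1/2 above).
void apply_fused_qkv_rope(NVTETensor qkv, NVTETensor q_freqs, NVTETensor k_freqs,
                          NVTETensor q, NVTETensor k, NVTETensor v, int s, int b,
                          int h, int d, int d2, int h_q, int h_k, int h_v,
                          cudaStream_t stream) {
  nvte_fused_qkv_rope_forward(qkv, q_freqs, k_freqs,
                              /*start_positions=*/nullptr,  // assumed optional here
                              q, k, v,
                              NVTE_SBHD,  // assumed sequence-major [s, b, h, d] layout
                              /*interleaved=*/false,
                              /*cp_size=*/1, /*cp_rank=*/0,  // no context parallelism
                              s, b, h, d, d2, h_q, h_k, h_v, stream);
}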
transformer_engine/common/include/transformer_engine/gemm.h

@@ -133,21 +133,13 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
  *  \param[in]     math_sm_count   Number of GPU SMs to use (default=0: use cuBLAS heuristics)
  *  \param[in]     stream          CUDA stream to wait on.
  */
-void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
-                                   const NVTETensor *bias, NVTETensor *pre_gelu_out,
-                                   const int num_gemms, bool transa, bool transb, bool grad,
-                                   NVTETensor *workspace, bool accumulate,
-                                   bool use_split_accumulator, int math_sm_count,
-                                   cudaStream_t stream);
+void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
+                            const NVTETensor *bias, NVTETensor *pre_gelu_out,
+                            const int num_gemms, bool transa, bool transb, bool grad,
+                            NVTETensor *workspace, bool accumulate,
+                            bool use_split_accumulator, int math_sm_count, cudaStream_t stream);
 
 #ifdef __HIP_PLATFORM_AMD__
-void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
-                       const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
-                       bool transa, bool transb, bool grad, NVTETensor *workspace,
-                       bool accumulate, bool use_split_accumulator, int math_sm_count,
-                       cudaStream_t stream);
-
 void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                         const NVTETensor *bias, NVTETensor *pre_gelu_out,
                                         const int num_gemms, bool transa, bool transb, bool grad,
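This hunk renames the multi-GEMM entry point (nvte_multi_stream_cublas_gemm becomes nvte_multi_tensor_gemm) and drops the ROCm-only nvte_grouped_gemm declaration. A sketch of a call through the new name follows; the handle arrays are assumed to be populated by the caller, and the flag values are illustrative defaults rather than anything mandated by the header.

#include <transformer_engine/gemm.h>

// Illustrative driver for num_gemms independent GEMMs through the renamed API.
void run_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                           const NVTETensor *bias, NVTETensor *pre_gelu_out,
                           NVTETensor *workspace, int num_gemms, cudaStream_t stream) {
  nvte_multi_tensor_gemm(A, B, D, bias, pre_gelu_out, num_gemms,
                         /*transa=*/true, /*transb=*/false, /*grad=*/false,
                         workspace, /*accumulate=*/false,
                         /*use_split_accumulator=*/false,
                         /*math_sm_count=*/0,  // 0: defer to cuBLAS heuristics
                         stream);
}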
transformer_engine/common/include/transformer_engine/normalization.h

@@ -24,7 +24,7 @@ extern "C" {
  * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta
  * @f]
  *
- * Calling this function with workspace set to empty tensor will not perform the operation,
+ * Calling this function with workspace set to an empty tensor will not perform the operation,
  * but instead set the shape and type of the workspace tensor to the required values.
  *
  *  \param[in]     x                     Input tensor of shape [N, H].
@@ -55,8 +55,8 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe
  * else
  * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
  *
- * Calling this function with workspace set to empty tensor will not perform the operation,
- * but instead set the shape and type of these tensors to the required values.
+ * Calling this function with workspace set to an empty tensor will not perform the operation,
+ * but instead set the shape and type of the workspace tensor to the required values.
  *
  *  \param[in]     dz                    Incoming gradient tensor of shape [N, H].
  *  \param[in]     x                     Forward input tensor of shape [N, H].
@@ -90,9 +90,8 @@ void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETenso
  * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
  * @f]
  *
- * Calling this function with workspace and barrier set to empty tensor will not
- * perform the operation, but instead set the shape and type of the workspace
- * and barrier tensors to the required values.
+ * Calling this function with workspace set to an empty tensor will not perform the operation,
+ * but instead set the shape and type of the workspace tensor to the required values.
  *
  *  \param[in]     x                     Input tensor of shape [N, H].
  *  \param[in]     gamma                 Gamma tensor of shape [H].
@@ -121,9 +120,8 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
  * @f]
  * with respect to \f$x\f$ and \f$gamma\f$.
  *
- * Calling this function with workspace, barrier, dgamma_part set
- * to empty tensor will not perform the operation, but instead set the shape and type
- * of these tensors to the required values.
+ * Calling this function with workspace set to an empty tensor will not perform the operation,
+ * but instead set the shape and type of the workspace tensor to the required values.
  *
  *  \param[in]     dz                    Incoming gradient tensor of shape [N, H].
  *  \param[in]     x                     Forward input tensor of shape [N, H].
@@ -142,6 +140,29 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
                       NVTETensor workspace, const int multiprocessorCount,
                       const bool zero_centered_gamma, cudaStream_t stream);
 
+/*! \brief Compute backward of RMSNorm and add additional tensor to output gradient
+ *
+ * Calling this function with workspace set to an empty tensor will not perform the operation,
+ * but instead set the shape and type of the workspace tensor to the required values.
+ *
+ *  \param[in]     dz                    Incoming gradient tensor of shape [N, H].
+ *  \param[in]     x                     Forward input tensor of shape [N, H].
+ *  \param[in]     add                   Additional tensor to add to output gradient [N, H].
+ *  \param[in]     rsigma                Reciprocal of the root mean square of the input
+ *                                       calculated over the last dimension. Shape: [N].
+ *  \param[in]     gamma                 Gamma tensor of shape [H].
+ *  \param[out]    dx                    Output gradient of shape [N, H].
+ *  \param[out]    dgamma                Gradient for gamma tensor of shape [H].
+ *  \param[out]    workspace             Workspace tensor.
+ *  \param[in]     multiprocessorCount   Number of SMs in the device.
+ *  \param[in]     zero_centered_gamma   Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
+ *  \param[in]     stream                CUDA stream used for the operation.
+ */
+void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETensor add,
+                          const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
+                          NVTETensor dgamma, NVTETensor workspace,
+                          const int multiprocessorCount, const bool zero_centered_gamma,
+                          cudaStream_t stream);
+
 /*! \brief Helper to enable cuDNN backend for normalization
  *
  *  \param[in]     bool                  Enable if True
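Every normalization entry point in this header, including the new nvte_rmsnorm_bwd_add, follows the two-pass workspace convention spelled out in the doc comments: a first call with an empty workspace tensor only reports the required shape and dtype, and a second call runs the kernel. A sketch of that pattern follows; allocate_as_queried is a hypothetical stand-in for however the caller attaches device memory, not a TE function.

#include <transformer_engine/normalization.h>

// Hypothetical helper: reads the shape/dtype the query pass wrote into `ws`
// and backs the tensor with device storage of that size.
void allocate_as_queried(NVTETensor ws);

void rmsnorm_bwd_add_two_pass(NVTETensor dz, NVTETensor x, NVTETensor add,
                              NVTETensor rsigma, NVTETensor gamma, NVTETensor dx,
                              NVTETensor dgamma, NVTETensor ws, int sm_count,
                              cudaStream_t stream) {
  // Pass 1: ws is empty, so this only records the required workspace shape/type.
  nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, ws, sm_count,
                       /*zero_centered_gamma=*/false, stream);
  allocate_as_queried(ws);
  // Pass 2: same arguments with a real workspace; the fused backward+add runs.
  nvte_rmsnorm_bwd_add(dz, x, add, rsigma, gamma, dx, dgamma, ws, sm_count,
                       /*zero_centered_gamma=*/false, stream);
}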
transformer_engine/common/include/transformer_engine/recipe.h

@@ -86,6 +86,21 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
 void nvte_compute_channel_colwise_amax(const NVTETensor input, NVTETensor output,
                                        const NVTETensor fp8_scale, cudaStream_t stream);
 
+/*! \brief Compute an FP8 tensor's amax with quantization config.
+ *
+ *  The amax (maximum absolute value) of the input tensor is computed
+ *  and written to the amax buffer of the output tensor, using the provided
+ *  quantization configuration.
+ *  One useful config is the noop tensor, which is needed by cuda graph.
+ *
+ *  \param[in]     input     Input tensor. Must be unquantized.
+ *  \param[in,out] output    Output tensor. Must be an FP8 tensor with per-tensor scaling.
+ *  \param[in]     config    Quantization configuration.
+ *  \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
+                                   const NVTEQuantizationConfig config, cudaStream_t stream);
+
 /*! \brief Update an FP8 tensor's scale based on its amax.
  *
  * This is only supported for FP8 tensors with per-tensor scaling.
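The point of the config argument, per the doc comment, is CUDA-graph friendliness: a noop flag tensor presumably lets a captured graph skip the reduction at replay time without changing the graph's topology. A sketch of a call site follows; building the NVTEQuantizationConfig is deliberately left to the caller, since the config-construction helpers are outside this diff.

#include <transformer_engine/recipe.h>

// Sketch: `config` is assumed to carry a noop flag tensor so the amax
// reduction stays replay-safe under CUDA graph capture.
void record_amax(NVTETensor input, NVTETensor fp8_output,
                 NVTEQuantizationConfig config, cudaStream_t stream) {
  // Writes max(|input|) into fp8_output's amax buffer, honoring the config.
  nvte_compute_amax_with_config(input, fp8_output, config, stream);
}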
transformer_engine/common/multi_tensor/l2norm.cu

@@ -421,6 +421,7 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
       reinterpret_cast<float *>(ret.data.dptr),
       per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
       max_chunks_per_tensor);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
@@ -448,6 +449,7 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
       reinterpret_cast<float *>(ret.data.dptr),
       per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
       max_chunks_per_tensor);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 }  // namespace multi_tensor_l2norm
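Both additions follow a pattern applied throughout this merge: a kernel launch is fire-and-forget, so a launch failure only surfaces through cudaGetLastError, and NVTE_CHECK_CUDA turns that deferred status into an immediate error. NVTE_CHECK_CUDA itself lives in TE's internal headers and is not shown in this diff; below is a minimal stand-in with the same shape, for illustration only.

#include <cuda_runtime.h>

#include <stdexcept>
#include <string>

// Stand-in for TE's NVTE_CHECK_CUDA (the real macro may differ): evaluate a
// CUDA runtime call and throw on anything but cudaSuccess.
#define CHECK_CUDA_SKETCH(expr)                                      \
  do {                                                               \
    const cudaError_t err_ = (expr);                                 \
    if (err_ != cudaSuccess) {                                       \
      throw std::runtime_error(std::string("CUDA error: ") +         \
                               cudaGetErrorString(err_));            \
    }                                                                \
  } while (false)

// Usage mirrors the hunks above: check the deferred status right after launch.
//   my_kernel<<<grid, block>>>(args);
//   CHECK_CUDA_SKETCH(cudaGetLastError());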
transformer_engine/common/normalization/common.cpp

@@ -140,8 +140,8 @@ void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
     if (_launch_params.barrier_bytes > 0) {
       _launch_params.params.barrier =
           reinterpret_cast<int *>(workspace_dptr + _launch_params.workspace_bytes);
-      cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes,
-                      _launch_params.stream);
+      NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0,
+                                      _launch_params.barrier_bytes, _launch_params.stream));
     }
     if constexpr (std::is_same_v<KernelParamsType, BackwardKernelParams>) {
       _launch_params.params.dgamma_part =
@@ -158,7 +158,7 @@ void TeNormalizationPlan<KernelParamsType>::_set_workspace() {
 template <>
 void TeNormalizationPlan<ForwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
                                                        void* mean_dptr, void* rsigma_dptr,
                                                        void* dx_dptr, void* dz_dptr,
-                                                       void* dbeta_dptr, void* dgamma_dptr,
+                                                       void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
                                                        void* workspace_dptr, cudaStream_t stream) {
   NVTE_ERROR("Forward normalization should not call the backward execute function!");
@@ -168,8 +168,9 @@ template <>
 void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamma_dptr,
                                                         void* mean_dptr, void* rsigma_dptr,
                                                         void* dx_dptr, void* dz_dptr,
-                                                        void* dbeta_dptr, void* dgamma_dptr,
-                                                        void* workspace_dptr, cudaStream_t stream) {
+                                                        void* add_dptr, void* dbeta_dptr,
+                                                        void* dgamma_dptr, void* workspace_dptr,
+                                                        cudaStream_t stream) {
   _launch_params.stream = stream;
   auto& kernel_params = _launch_params.params;
@@ -179,6 +180,7 @@ void TeNormalizationPlan<BackwardKernelParams>::execute(void* x_dptr, void* gamm
   kernel_params.rs = rsigma_dptr;
   kernel_params.dx = dx_dptr;
   kernel_params.dz = dz_dptr;
+  kernel_params.add = add_dptr;
   kernel_params.dgamma = dgamma_dptr;
   if (_is_layernorm) {
@@ -467,11 +469,14 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr,
 void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr,
                                      void* rsigma_dptr, void* dx_dptr, void* dz_dptr,
-                                     void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
-                                     cudaStream_t stream) {
+                                     void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
+                                     void* workspace_dptr, cudaStream_t stream) {
 #ifdef USE_ROCM
   assert(false);
 #else
+  // cuDNN does not currently support fused backward+add
+  NVTE_CHECK(add_dptr == nullptr);
+
   // Binding data pointers to graph tensors
   _variant_pack = {
       {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
transformer_engine/common/normalization/common.h

@@ -130,6 +130,9 @@ struct BackwardKernelParams : public KernelParamsBase {
   // Input: gradient wrt. LN FWD output.
   void* dz;
 
+  // Input: extra tensor to add for fused backward+add
+  void* add;
+
   // Workspace for Wgrad pre-reduction.
   void* dbeta_part;
   void* dgamma_part;
@@ -141,8 +144,10 @@ struct BackwardKernelParams : public KernelParamsBase {
   void* dgamma;
 };
 
+using BackwardAddKernelParams = BackwardKernelParams;
+
 enum class NVTE_Norm_Backend { Te, Cudnn };
-enum class NVTE_Norm_Stage { Forward, Backward };
+enum class NVTE_Norm_Stage { Forward, Backward, BackwardAdd };
 
 using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
 struct TupleHash {
@@ -225,8 +230,8 @@ class NormalizationPlanBase {
                        cudaStream_t stream) = 0;
 
   virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr,
-                       void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr,
-                       void* workspace_dptr, cudaStream_t stream) = 0;
+                       void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr,
+                       void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) = 0;
 
  private:
   virtual void _build() = 0;
@@ -245,8 +250,8 @@ class TeNormalizationPlan : public NormalizationPlanBase {
                cudaStream_t stream) override;
 
   void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
-               void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
-               cudaStream_t stream) override;
+               void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
+               void* workspace_dptr, cudaStream_t stream) override;
 
  private:
   void _set_workspace();
@@ -274,8 +279,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
               cudaStream_t stream) override;
 
   void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr,
-               void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr,
-               cudaStream_t stream) override;
+               void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr,
+               void* workspace_dptr, cudaStream_t stream) override;
 
  private:
   void _build() override;
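The BackwardAddKernelParams alias and the new BackwardAdd stage make the fused path largely a dispatch-key change: the kernel-parameter layout is identical to the plain backward, and the stage value only selects the fused kernel instantiation in the plan registry. The lookup made by rmsnorm_bwd_add in rmsnorm_api.cpp further down in this diff has this shape; wtype/itype/otype/batch_size/hidden_size/sm_count here are descriptive placeholders for the values that call passes.

// Mirrors the registry lookup in rmsnorm_api.cpp below.
auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
    NVTE_Norm_Backend::Te, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,
    wtype, itype, otype, batch_size, hidden_size, sm_count,
    /*zero_centered_gamma=*/false, /*is_aligned=*/true,
    NVTE_DELAYED_TENSOR_SCALING,
    true,  // passed as-is in rmsnorm_bwd_add (presumably the training flag)
    /*gamma_in_weight_dtype=*/false);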
transformer_engine/common/normalization/layernorm/ln_api.cpp

@@ -73,7 +73,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
 #endif
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
-    cudnn_backend = false;
+    // cuDNN does not currently support amax output for non quantized output
+    NVTE_CHECK(!cudnn_backend, "cuDNN does not currently support amax output for non quantized output");
   }
 
   bool gamma_in_weight_dtype = false;
@@ -192,7 +193,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
   } else {
     NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
     plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr,
-                  dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr,
-                  stream);
+                  dz.data.dptr, nullptr /*add*/, dbeta->data.dptr, dgamma->data.dptr,
+                  workspace->data.dptr, stream);
   }
   return;
 }
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu

@@ -14,16 +14,16 @@ using namespace transformer_engine::normalization;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL>
-void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
-                   const bool configure_params) {  // NOLINT(*)
+void launch_ln_bwd_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
+                          const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
   auto kernel = &ln_bwd_tuned_kernel<Kernel_traits>;
 
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES));
     launch_params.params.ctas_per_row = CTAS_PER_ROW;
     launch_params.params.ctas_per_col =
         launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
@@ -57,13 +57,14 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
-                                stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_),
+                                                Kernel_traits::SMEM_BYTES, stream));
   }
 
   using Kernel_traits_f =
@@ -74,13 +75,14 @@ void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
   auto kernel_f = &ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;
   kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
       launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
           int BYTES_PER_LDG_FINAL>
-void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
-                     const bool configure_params) {  // NOLINT(*)
+void launch_ln_bwd_general_(LaunchParams<BackwardKernelParams> &launch_params,
+                            const bool configure_params) {  // NOLINT(*)
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
 
   // Instantiate kernel
@@ -95,8 +97,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   int ctas_per_row = launch_params.params.ctas_per_row;
   if (configure_params) {
     int ctas_per_sm;
-    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
-                                                  Kernel_traits::THREADS_PER_CTA, 0);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
     const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
     ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
     ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
@@ -117,10 +119,11 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), 0, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_), 0, stream));
   }
 
   // Launch finalization kernel
@@ -134,6 +137,7 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
   dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
   kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
@@ -142,8 +146,8 @@ void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
   void \
   norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
       LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
-    launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
-        launch_params, configure_params); \
+    launch_ln_bwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
+                                   __VA_ARGS__>(launch_params, configure_params); \
   } \
   REGISTER_NORM_BASE( \
       NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
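The rename from launch_tuned_/launch_general_ to launch_ln_bwd_* (and to ln_fwd_* and rmsnorm_* in the sibling files) is plumbed through REGISTER_NORM_LAUNCHER by pasting a per-file prefix into the helper name, so the normalization translation units stop depending on identically named file-local launchers. A reduced, self-contained illustration of the token-pasting pattern; the names here are illustrative, not TE's.

#include <cstdio>

// A file-local launcher whose name carries the per-file prefix.
template <int HIDDEN_SIZE>
void launch_ln_bwd_tuned_(bool configure_params) {
  std::printf("ln_bwd tuned launcher, hidden=%d, configure=%d\n", HIDDEN_SIZE,
              static_cast<int>(configure_params));
}

// The macro splices PREFIX into both the generated symbol and the callee,
// mirroring the launch_ln_bwd_##LAUNCH_TYPE##_ splice in the hunk above.
#define REGISTER_LAUNCHER(PREFIX, LAUNCH_TYPE, HIDDEN_SIZE)             \
  void norm_##PREFIX##_##LAUNCH_TYPE##_##HIDDEN_SIZE(bool configure) {  \
    launch_##PREFIX##_##LAUNCH_TYPE##_<HIDDEN_SIZE>(configure);         \
  }

REGISTER_LAUNCHER(ln_bwd, tuned, 1024)  // defines norm_ln_bwd_tuned_1024

int main() {
  norm_ln_bwd_tuned_1024(false);
  return 0;
}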
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu

@@ -13,15 +13,15 @@ using namespace transformer_engine::normalization;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG>
-void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
-                   const bool configure_params) {  // NOLINT(*)
+void launch_ln_fwd_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
+                          const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &ln_fwd_tuned_kernel<Kernel_traits>;
 
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
     launch_params.params.ctas_per_row = CTAS_PER_ROW;
     launch_params.params.ctas_per_col =
         launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
@@ -53,19 +53,21 @@ void launch_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_,  // NOLINT(*)
-                                Kernel_traits::SMEM_BYTES_FWD, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_),
+                                                Kernel_traits::SMEM_BYTES_FWD, stream));
   }
 }
 
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
-void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
-                     const bool configure_params) {  // NOLINT(*)
+void launch_ln_fwd_general_(LaunchParams<ForwardKernelParams> &launch_params,
+                            const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &ln_fwd_general_kernel<Kernel_traits>;
@@ -78,8 +80,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
   int ctas_per_row = launch_params.params.ctas_per_row;
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
     const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
     ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
     ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
@@ -99,10 +101,11 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), 0, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_), 0, stream));
   }
 }
@@ -112,8 +115,8 @@ void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
   void \
   norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
       LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
-    launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
-        launch_params, configure_params); \
+    launch_ln_fwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
+                                   __VA_ARGS__>(launch_params, configure_params); \
   } \
   REGISTER_NORM_BASE( \
       NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp

@@ -59,7 +59,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
 #endif
   if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
-    cudnn_backend = false;
+    // cuDNN does not currently support amax output for non quantized output
+    NVTE_CHECK(!cudnn_backend, "cuDNN does not currently support amax output for non quantized output");
   }
 
   bool training =
@@ -169,7 +170,74 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
   } else {
     NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
     plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
-                  dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr,
-                  stream);
+                  dz.data.dptr, nullptr /*add*/, nullptr /*dbeta*/, dgamma->data.dptr,
+                  workspace->data.dptr, stream);
   }
   return;
 }
+
+void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const Tensor &rsigma,
+                     const Tensor &gamma, Tensor *dx, Tensor *dgamma, Tensor *workspace,
+                     const int multiprocessorCount, const bool zero_centered_gamma,
+                     cudaStream_t stream) {
+  using namespace transformer_engine;
+  NVTE_CHECK(dz.data.dtype == gamma.data.dtype);
+  NVTE_CHECK(add.data.dtype == gamma.data.dtype);
+  NVTE_CHECK(rsigma.data.dtype == DType::kFloat32);
+
+  NVTE_CHECK(x.data.shape.size() == 2);
+  NVTE_CHECK(dz.data.shape == x.data.shape);
+  NVTE_CHECK(add.data.shape == x.data.shape);
+  NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]);
+
+  NVTE_CHECK(dx->data.shape == x.data.shape);
+  NVTE_CHECK(dx->data.dtype == x.data.dtype);
+
+  NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
+  NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
+
+  if (!workspace->data.shape.empty()) {
+    CheckInputTensor(dz, "dz");
+    CheckInputTensor(x, "x");
+    CheckInputTensor(add, "add");
+    CheckInputTensor(rsigma, "rsigma");
+    CheckInputTensor(gamma, "gamma");
+    CheckOutputTensor(*dx, "dx");
+    CheckOutputTensor(*dgamma, "dgamma");
+  }
+
+  // cuDNN does not currently support fused backward+add
+  NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
+  // TE backend does not currently support zero_centered_gamma_in_weight_dtype
+  NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(),
+             "zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add");
+  bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
+                                   dz.data.dptr, dgamma->data.dptr, add.data.dptr);
+  bool gamma_in_weight_dtype = false;
+  auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan(
+      norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd,
+      gamma.data.dtype,  // wtype
+      x.data.dtype,      // itype
+      gamma.data.dtype,  // otype
+      x.data.shape[0],   // batch_size
+      x.data.shape[1],   // hidden_size
+      multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
+      gamma_in_weight_dtype);
+
+  if (workspace->data.shape.empty()) {
+    workspace->data.shape = plan->getWorkspaceShape();
+    workspace->data.dtype = DType::kByte;
+    return;
+  } else {
+    NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape());
+    plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr,
+                  dz.data.dptr, add.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr,
+                  workspace->data.dptr, stream);
+  }
+
+  return;
+}
@@ -202,3 +270,19 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size
                   convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace),
                   multiprocessorCount, zero_centered_gamma, stream);
 }
+
+void nvte_rmsnorm_bwd_add(const NVTETensor dz,      // Nxhidden_size
+                          const NVTETensor x,       // Nxhidden_size
+                          const NVTETensor add,     // Nxhidden_size
+                          const NVTETensor rsigma,  // N, FP32!
+                          const NVTETensor gamma,   // hidden_size
+                          NVTETensor dx, NVTETensor dgamma, NVTETensor workspace,
+                          const int multiprocessorCount, const bool zero_centered_gamma,
+                          cudaStream_t stream) {
+  NVTE_API_CALL(nvte_rmsnorm_bwd_add);
+  using namespace transformer_engine;
+  rmsnorm_bwd_add(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x),
+                  *convertNVTETensorCheck(add), *convertNVTETensorCheck(rsigma),
+                  *convertNVTETensorCheck(gamma), convertNVTETensor(dx),
+                  convertNVTETensor(dgamma), convertNVTETensor(workspace),
+                  multiprocessorCount, zero_centered_gamma, stream);
+}
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh

@@ -7,13 +7,31 @@
 #ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
 #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_
 
+#include <type_traits>
+
 #include "../../utils.cuh"
 #include "../common.h"
 
 namespace transformer_engine {
 namespace normalization {
 
-template <typename Ktraits>
+struct maybe_not_t {};
+template <typename T, bool Enabled>
+using maybe_t = std::conditional_t<Enabled, T, maybe_not_t>;
+
+template <typename Ivec, typename Ovec, bool FusedAdd>
+union dx_add_t {
+  using add_t = maybe_t<Ovec, FusedAdd>;
+  using dx_t = Ivec;
+  struct {
+    char _padding[sizeof(dx_t) > sizeof(add_t) ? sizeof(dx_t) - sizeof(add_t) : 0];
+    add_t add;
+  };
+  dx_t dx;
+};
+
+template <typename Ktraits, bool FusedAdd>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel(
     BackwardKernelParams params) {
   enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
@@ -111,10 +129,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
       }
     }
 
+    dx_add_t<Ivec, Ovec, FusedAdd> temp[LDGS];
+    if constexpr (FusedAdd) {
+      idx = row * Ktraits::VEC_COLS + c;
+#pragma unroll
+      for (int it = 0; it < LDGS; it++) {
+        temp[it].add.load_from(params.add, idx);
+        idx += Ktraits::VEC_COLS_PER_LDG;
+      }
+    }
+
     reduce_t result = reducer.allreduce({0, mdyy_local}, sum);
     mdyy_local = Get<1>::of<reduce_t, compute_t>(result) * rn;
 
-    Ivec dx[LDGS];
     idx = row * Ktraits::VEC_COLS + c;
 #pragma unroll
     for (int it = 0; it < LDGS; it++) {
@@ -123,9 +150,13 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke
         compute_t dy_tmp = dy[it * NUM_ELTS + jt];
         compute_t y_tmp = y[it * NUM_ELTS + jt];
         compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp));
-        dx[it].data.elt[jt] = dx_tmp;
+        if constexpr (FusedAdd) {
+          compute_t add_tmp = temp[it].add.data.elt[jt];
+          dx_tmp += add_tmp;
+        }
+        temp[it].dx.data.elt[jt] = dx_tmp;
       }
-      dx[it].store_to(params.dx, idx);
+      temp[it].dx.store_to(params.dx, idx);
       idx += Ktraits::VEC_COLS_PER_LDG;
     }
   }  // end: grid stride loop
@@ -274,7 +305,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi
   }
 }
 
-template <typename Ktraits>
+template <typename Ktraits, bool FusedAdd>
 __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel(
     BackwardKernelParams params) {
   enum { LDGS = Ktraits::LDGS };
@@ -379,14 +410,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_
 #pragma unroll
     for (int it = 0, col = gidn * NUM_ELTS;
          it < LDGS && row < params.rows && col < params.cols;
         it++, col += gdimn * NUM_ELTS) {
-      Ivec dx;
+      dx_add_t<Ivec, Ovec, FusedAdd> temp;
+      if constexpr (FusedAdd) {
+        temp.add.load_from_elts(params.add, row * params.cols + col, params.cols - col);
+      }
 #pragma unroll
      for (int jt = 0; jt < NUM_ELTS; jt++) {
        compute_t dy_ij = dy[it].data.elt[jt];
        compute_t y_ij = y[it].data.elt[jt];
-        dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij));
+        compute_t dx_ij = rs * (dy_ij - (mdyy * y_ij));
+        if constexpr (FusedAdd) {
+          compute_t add_ij = temp.add.data.elt[jt];
+          dx_ij += add_ij;
+        }
+        temp.dx.data.elt[jt] = dx_ij;
      }
-      dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
+      temp.dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
    }
  }
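The dx_add_t union is the heart of the fused path: the residual input (add) and the output gradient (dx) share one block of registers, and the padding pins add to the tail of dx's storage. With that alignment, writing dx element j can only overwrite add elements of index at most j, which the loop has already consumed, so the overlay is safe under the strict read-elt-j-then-write-elt-j order the kernels use. When FusedAdd is false, maybe_t collapses add to an empty placeholder and the if constexpr branches vanish, so the plain backward pays nothing. A reduced, self-contained illustration of the conditional-member pattern follows; names here are illustrative, not TE's.

#include <type_traits>

struct empty_t {};  // counterpart of maybe_not_t above

template <typename T, bool Enabled>
using maybe_t = std::conditional_t<Enabled, T, empty_t>;

// When Enabled is false, `payload` degenerates to an empty struct: the union
// costs no extra storage and any `if constexpr` branch touching it compiles away.
template <typename Out, typename In, bool Enabled>
union overlay_t {
  In in;
  maybe_t<Out, Enabled> payload;
};

// Disabled payload adds nothing to the footprint.
static_assert(sizeof(overlay_t<double, int, false>) == sizeof(int));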
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
View file @
27ddce40
...
@@ -12,18 +12,17 @@ using namespace transformer_engine::normalization;
...
@@ -12,18 +12,17 @@ using namespace transformer_engine::normalization;
template
<
typename
weight_t
,
typename
input_t
,
typename
output_t
,
typename
compute_t
,
template
<
typename
weight_t
,
typename
input_t
,
typename
output_t
,
typename
compute_t
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
int
BYTES_PER_LDG_MAIN
,
int
BYTES_PER_LDG_FINAL
>
int
BYTES_PER_LDG_MAIN
,
int
BYTES_PER_LDG_FINAL
,
bool
FUSED_ADD
=
false
>
void
launch_tuned_
(
LaunchParams
<
BackwardKernelParams
>*
plaunch_params
,
void
launch_rmsnorm_bwd_tuned_
(
LaunchParams
<
BackwardKernelParams
>
&
launch_params
,
const
bool
configure_params
)
{
// NOLINT(*)
const
bool
configure_params
)
{
// NOLINT(*)
LaunchParams
<
BackwardKernelParams
>&
launch_params
=
*
plaunch_params
;
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG_MAIN
>
;
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG_MAIN
>
;
auto
kernel
=
&
rmsnorm_bwd_tuned_kernel
<
Kernel_traits
>
;
auto
kernel
=
&
rmsnorm_bwd_tuned_kernel
<
Kernel_traits
,
FUSED_ADD
>
;
if
(
configure_params
)
{
if
(
configure_params
)
{
int
ctas_per_sm
;
int
ctas_per_sm
=
0
;
cudaError
status_
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
NVTE_CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
);
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
)
)
;
launch_params
.
params
.
ctas_per_row
=
CTAS_PER_ROW
;
launch_params
.
params
.
ctas_per_row
=
CTAS_PER_ROW
;
launch_params
.
params
.
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
=
launch_params
.
multiprocessorCount
*
ctas_per_sm
/
launch_params
.
params
.
ctas_per_row
;
launch_params
.
multiprocessorCount
*
ctas_per_sm
/
launch_params
.
params
.
ctas_per_row
;
...
@@ -37,19 +36,17 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
...
@@ -37,19 +36,17 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
         launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
     return;
   }

-#ifdef __HIP_PLATFORM_AMD__
+#ifndef __HIP_PLATFORM_AMD__
   if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
-    NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                         Kernel_traits::SMEM_BYTES));
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel,
+                                         cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                         Kernel_traits::SMEM_BYTES));
   }
 #else
   if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
-    NVTE_CHECK_CUDA(cudaFuncSetAttribute((const void *)kernel,
-                                         cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                         Kernel_traits::SMEM_BYTES));
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                         Kernel_traits::SMEM_BYTES));
   }
 #endif
   auto stream = launch_params.stream;
   auto ctas_per_col = launch_params.params.ctas_per_col;
   auto ctas_per_row = launch_params.params.ctas_per_row;
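The >= 48 * 1024 guard exists because kernels requesting more dynamic shared memory than the 48 KB default must opt in explicitly before launch. A minimal sketch of that opt-in, with big_smem_kernel as an illustrative stand-in:

  #include <cuda_runtime.h>

  extern __shared__ char smem[];
  __global__ void big_smem_kernel() { smem[threadIdx.x] = 0; }

  void launch_with_big_smem(size_t smem_bytes, cudaStream_t stream) {
    if (smem_bytes >= 48 * 1024) {
      // Without this attribute, launches above 48 KB fail with
      // cudaErrorInvalidValue on most architectures.
      cudaFuncSetAttribute(big_smem_kernel,
                           cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
    big_smem_kernel<<<1, 128, smem_bytes, stream>>>();
  }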
...
@@ -57,13 +54,14 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), Kernel_traits::SMEM_BYTES,
-                                stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_),
+                                                Kernel_traits::SMEM_BYTES, stream));
   }
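The multi-CTA-per-row path goes through cudaLaunchCooperativeKernel because the kernel synchronizes across the whole grid, and wrapping it in the check macro matters: unlike <<<>>>, this launcher reports failures (for example, more blocks than fit co-resident) through its return value. A minimal sketch of a cooperative launch, with grid_kernel as an illustrative name and assuming the device supports cooperative launch:

  #include <cooperative_groups.h>
  #include <cuda_runtime.h>
  #include <cstdio>

  __global__ void grid_kernel(float *data) {
    auto grid = cooperative_groups::this_grid();
    data[grid.thread_rank()] = 1.f;
    grid.sync();  // legal only for kernels launched cooperatively
  }

  void launch(float *data, int blocks, int threads, cudaStream_t stream) {
    void *args[] = {&data};
    cudaError_t err = cudaLaunchCooperativeKernel(
        reinterpret_cast<void *>(grid_kernel), dim3(blocks), dim3(threads), args,
        /*sharedMem=*/0, stream);
    if (err != cudaSuccess) std::printf("%s\n", cudaGetErrorString(err));
  }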
   using Kernel_traits_f =
...
@@ -74,20 +72,20 @@ void launch_tuned_(LaunchParams<BackwardKernelParams>* plaunch_params,
   auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel<Kernel_traits_f>;
   kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
       launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
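For context, the workspace sized as ctas_per_col * cols * sizeof(compute_t) earlier in this file and the finalize kernel here suggest a two-pass reduction: the main kernel leaves partial per-column sums, and the finalize kernel folds them. A hedged sketch of such a second pass; the name and layout are illustrative, not this file's actual kernel:

  // Pass 2 of a two-pass column reduction: fold num_partial_rows rows of
  // partial sums (one per CTA group from pass 1) into a single output row.
  __global__ void finalize_colsum(const float *partials, float *out,
                                  int num_partial_rows, int cols) {
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= cols) return;
    float acc = 0.f;
    for (int r = 0; r < num_partial_rows; ++r) {
      acc += partials[r * cols + c];  // column-wise reduction of partial sums
    }
    out[c] = acc;
  }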
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG_MAIN,
-          int BYTES_PER_LDG_FINAL>
-void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
-                     const bool configure_params) {  // NOLINT(*)
-  LaunchParams<BackwardKernelParams>& launch_params = *plaunch_params;
+          int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false>
+void launch_rmsnorm_bwd_general_(LaunchParams<BackwardKernelParams> &launch_params,
+                                 const bool configure_params) {  // NOLINT(*)
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };

   // Instantiate kernel
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       1, WARPS_M, WARPS_N, BYTES_PER_LDG_MAIN>;
-  auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits>;
+  auto kernel = &rmsnorm_bwd_general_kernel<Kernel_traits, FUSED_ADD>;

   // Configure kernel params
   const int rows = launch_params.params.rows;
...
@@ -95,9 +93,9 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
   int ctas_per_col = launch_params.params.ctas_per_col;
   int ctas_per_row = launch_params.params.ctas_per_row;
   if (configure_params) {
-    int ctas_per_sm;
-    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel,
-                                                  Kernel_traits::THREADS_PER_CTA, 0);
+    int ctas_per_sm = 0;
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
     const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
     ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
     ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
...
@@ -120,10 +118,11 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), 0, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_), 0, stream));
   }

   // Launch finalization kernel
...
@@ -137,6 +136,7 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
   dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
   dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
   kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
@@ -145,15 +145,15 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
...
@@ -145,15 +145,15 @@ void launch_general_(LaunchParams<BackwardKernelParams>* plaunch_params,
void \
void \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE,
__VA_ARGS__>(
\
launch_
rmsnorm_bwd_
##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE,
\
&launch_params, configure_params);
\
__VA_ARGS__>(launch_params, configure_params);
\
} \
} \
REGISTER_NORM_BASE( \
REGISTER_NORM_BASE( \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \
} // namespace
} // namespace
// Create rmsnorm tuned launch function and register. Macro signature:
// Create rmsnorm
bwd
tuned launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
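REGISTER_NORM_LAUNCHER token-pastes one uniquely named launcher per (norm, stage, launch type, hidden size, dtypes) tuple and registers it. A minimal sketch of the token-pasting registration idiom with an illustrative registry; the map and macro here are not TE's real machinery:

  #include <functional>
  #include <map>
  #include <string>

  using LaunchFn = std::function<void(int)>;
  std::map<std::string, LaunchFn> &registry() {
    static std::map<std::string, LaunchFn> r;
    return r;
  }

  #define REGISTER_LAUNCHER(TYPE, HIDDEN)                        \
    void launch_##TYPE##_##HIDDEN(int n);                        \
    static bool registered_##TYPE##_##HIDDEN = [] {              \
      registry()[#TYPE "_" #HIDDEN] = launch_##TYPE##_##HIDDEN;  \
      return true;                                               \
    }();                                                         \
    void launch_##TYPE##_##HIDDEN(int n) { /* dispatch body */ }

  REGISTER_LAUNCHER(rmsnorm, 4096)  // defines and registers launch_rmsnorm_4096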
...
@@ -181,7 +181,7 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1
 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

-// Create rmsnorm general launch function and register. Macro signature:
+// Create rmsnorm bwd general launch function and register. Macro signature:
 //  HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
 //  WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
...
@@ -214,3 +214,108 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32,
 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
+
+// Create fused rmsnorm bwd + add tuned launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ...
+//  WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, true);
+
+// Create fused rmsnorm bwd + add general launch function and register. Macro signature:
+//  HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ...
+//  WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4, true);
+REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4, true);
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu
View file @ 27ddce40
...
@@ -13,17 +13,16 @@ using namespace transformer_engine::normalization;
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int CTAS_PER_ROW, int WARPS_M, int WARPS_N,
           int BYTES_PER_LDG>
-void launch_tuned_(LaunchParams<ForwardKernelParams>* plaunch_params,
-                   const bool configure_params) {  // NOLINT(*)
-  LaunchParams<ForwardKernelParams>& launch_params = *plaunch_params;
+void launch_rmsnorm_fwd_tuned_(LaunchParams<ForwardKernelParams> &launch_params,
+                               const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &rmsnorm_fwd_tuned_kernel<Kernel_traits>;
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
     launch_params.params.ctas_per_row = CTAS_PER_ROW;
     launch_params.params.ctas_per_col =
         launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row;
...
@@ -55,20 +54,21 @@ void launch_tuned_(LaunchParams<ForwardKernelParams>* plaunch_params,
   if (ctas_per_row == 1) {
     kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(
         launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     dim3 grid(ctas_per_row * ctas_per_col);
     dim3 block(Kernel_traits::THREADS_PER_CTA);
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_,  // NOLINT(*)
-                                Kernel_traits::SMEM_BYTES_FWD, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_),
+                                                Kernel_traits::SMEM_BYTES_FWD, stream));
   }
 }
 template <typename weight_t, typename input_t, typename output_t, typename compute_t,
           typename index_t, int HIDDEN_SIZE, int WARPS_M, int WARPS_N, int BYTES_PER_LDG>
-void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
-                     const bool configure_params) {  // NOLINT(*)
-  LaunchParams<ForwardKernelParams>& launch_params = *plaunch_params;
+void launch_rmsnorm_fwd_general_(LaunchParams<ForwardKernelParams> &launch_params,
+                                 const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
                                       1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
   auto kernel = &rmsnorm_fwd_general_kernel<Kernel_traits>;
...
@@ -81,8 +81,8 @@ void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
   int ctas_per_row = launch_params.params.ctas_per_row;
   if (configure_params) {
     int ctas_per_sm;
-    cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0);
+    NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0));
     const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm;
     ctas_per_row = ceil_div(cols, HIDDEN_SIZE);
     ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row);
...
@@ -102,10 +102,11 @@ void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
   dim3 block(Kernel_traits::THREADS_PER_CTA);
   if (ctas_per_row == 1) {
     kernel<<<grid, block, 0, stream>>>(launch_params.params);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     void *params_ = reinterpret_cast<void *>(&launch_params.params);
-    cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
-                                reinterpret_cast<void **>(&params_), 0, stream);
+    NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast<void *>(kernel), grid, block,
+                                                reinterpret_cast<void **>(&params_), 0, stream));
   }
 }
...
@@ -115,8 +116,8 @@ void launch_general_(LaunchParams<ForwardKernelParams>* plaunch_params,
   void \
   norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \
       LaunchParams<NORM_STAGE##KernelParams> &launch_params, const bool configure_params) { \
-    launch_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, __VA_ARGS__>( \
-        &launch_params, configure_params); \
+    launch_rmsnorm_fwd_##LAUNCH_TYPE##_<WTYPE, ITYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, \
+                                        __VA_ARGS__>(launch_params, configure_params); \
   } \
   REGISTER_NORM_BASE( \
       NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \
...
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
View file @ 27ddce40
...
@@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t
   switch (wait_kind) {
     case WaitKind::KERNEL_WAIT:
       wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
+      NVTE_CHECK_CUDA(cudaGetLastError());
       break;
     case WaitKind::NVSHMEM_WAIT:
       nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
-      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
-                           CU_STREAM_WRITE_VALUE_DEFAULT);
+      NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
+                                                  (cuuint64_t)signal_reset,
+                                                  CU_STREAM_WRITE_VALUE_DEFAULT));
       break;
     case WaitKind::STREAM_WAIT:
-      cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
-                          CU_STREAM_WAIT_VALUE_GEQ);
-      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
-                           CU_STREAM_WRITE_VALUE_DEFAULT);
+      NVTE_CHECK_CUDA_DRIVER(cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
+                                                 (cuuint64_t)wait_value, CU_STREAM_WAIT_VALUE_GEQ));
+      NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr,
+                                                  (cuuint64_t)signal_reset,
+                                                  CU_STREAM_WRITE_VALUE_DEFAULT));
       break;
   }
 }
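STREAM_WAIT above blocks the stream on a device flag without launching a kernel, then resets the flag in stream order. A minimal sketch of that driver-API pattern; the function name is illustrative and a properly initialized driver context is assumed:

  #include <cuda.h>

  void wait_then_reset(CUstream stream, CUdeviceptr flag,
                       cuuint64_t expected, cuuint64_t reset) {
    // Block later work on `stream` until *flag >= expected...
    cuStreamWaitValue64(stream, flag, expected, CU_STREAM_WAIT_VALUE_GEQ);
    // ...then reset the flag, still in stream order, so the next
    // iteration can reuse it without a host round-trip.
    cuStreamWriteValue64(stream, flag, reset, CU_STREAM_WRITE_VALUE_DEFAULT);
  }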
transformer_engine/common/permutation/permutation.cu
View file @ 27ddce40
...
@@ -251,6 +251,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
   moe_permute_row_map<<<blocks, threads, 0, stream>>>(sorted_row_id, row_id_map, num_rows, topK,
                                                       num_out_tokens);
+  NVTE_CHECK_CUDA(cudaGetLastError());

   blocks = num_rows;
 #ifdef __HIP_PLATFORM_AMD__
...
@@ -260,6 +261,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
 #endif
       moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>(
           input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
+      NVTE_CHECK_CUDA(cudaGetLastError());
     } else {
       // moe_unpermute_bwd
...
@@ -271,6 +273,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
         moe_permute_kernel<T, TCompute, 1, false><<<blocks, threads, 0, stream>>>(
             input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
+        NVTE_CHECK_CUDA(cudaGetLastError());
       } else {
         // moe_unpermute_bwd with probs
...
@@ -294,6 +297,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
       } else {
         NVTE_ERROR("topK cannot exceed 128.");
       }
+      NVTE_CHECK_CUDA(cudaGetLastError());
     }
   }
 }
...
@@ -322,11 +326,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f
     moe_unpermute_kernel<T, TCompute, false><<<blocks, threads, smem_bytes, stream>>>(
         input, output, row_id_map, nullptr, num_rows, topK, num_cols);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   } else {
     // moe_unpermute_fwd with probs
     moe_unpermute_kernel<T, TCompute, true><<<blocks, threads, smem_bytes, stream>>>(
         input, output, row_id_map, prob, num_rows, topK, num_cols);
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
 }
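The recurring change in this file, and in most others in the commit, is wrapping every launch with a cudaGetLastError() check, since the <<<>>> launch syntax itself returns nothing. A minimal sketch of such a check macro, assuming an abort-on-error policy; CHECK_CUDA here is illustrative, TE's real macro is NVTE_CHECK_CUDA:

  #include <cuda_runtime.h>
  #include <cstdio>
  #include <cstdlib>

  #define CHECK_CUDA(expr)                                           \
    do {                                                             \
      cudaError_t err_ = (expr);                                     \
      if (err_ != cudaSuccess) {                                     \
        std::fprintf(stderr, "CUDA error %s at %s:%d\n",             \
                     cudaGetErrorString(err_), __FILE__, __LINE__);  \
        std::abort();                                                \
      }                                                              \
    } while (0)

  __global__ void k() {}

  void launch(cudaStream_t stream) {
    k<<<1, 1, 0, stream>>>();
    // <<<>>> returns void; cudaGetLastError() surfaces invalid-configuration
    // errors that would otherwise show up on some later, unrelated call.
    CHECK_CUDA(cudaGetLastError());
  }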
...
transformer_engine/common/recipe/current_scaling.cu
View file @ 27ddce40
...
@@ -30,7 +30,11 @@ constexpr int amax_kernel_threads = 512;
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
     void amax_kernel(const InputType *input, float *amax, const size_t N,
-                     const size_t num_aligned_elements) {
+                     const size_t num_aligned_elements, const float *noop_ptr) {
+  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
+    return;
+  }
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max = 0.f;
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
...
@@ -124,9 +128,10 @@ void channel_colwise_amax_kernel_v2(const InputType* in, float* out, const float
 }

 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N,
+                        const float *noop_ptr, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
-  cudaMemsetAsync(amax, 0, sizeof(float), stream);
+  NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));

   // Return immediately if tensor is empty
   if (N == 0) {
@@ -147,16 +152,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
...
@@ -147,16 +152,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
switch
(
align
)
{
switch
(
align
)
{
case
Alignment
::
SAME_ALIGNED
:
case
Alignment
::
SAME_ALIGNED
:
amax_kernel
<
nvec
,
true
,
InputType
>
amax_kernel
<
nvec
,
true
,
InputType
>
<<<
num_blocks
,
threads
,
0
,
stream
>>>
(
input
,
amax
,
N
,
num_aligned_elements
);
<<<
num_blocks
,
threads
,
0
,
stream
>>>
(
input
,
amax
,
N
,
num_aligned_elements
,
noop_ptr
);
break
;
break
;
case
Alignment
::
SAME_UNALIGNED
:
case
Alignment
::
SAME_UNALIGNED
:
amax_kernel
<
nvec
,
false
,
InputType
>
amax_kernel
<
nvec
,
false
,
InputType
>
<<<
num_blocks
,
threads
,
0
,
stream
>>>
(
input
,
amax
,
N
,
num_aligned_elements
);
<<<
num_blocks
,
threads
,
0
,
stream
>>>
(
input
,
amax
,
N
,
num_aligned_elements
,
noop_ptr
);
break
;
break
;
case
Alignment
::
DIFFERENT
:
{
case
Alignment
::
DIFFERENT
:
{
// This case is a logic error, since there is only one pointer (input)
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
// in the alignment check. Still safe to process without vectorization.
amax_kernel
<
1
,
true
,
InputType
><<<
num_blocks
,
threads
,
0
,
stream
>>>
(
input
,
amax
,
N
,
N
);
amax_kernel
<
1
,
true
,
InputType
>
<<<
num_blocks
,
threads
,
0
,
stream
>>>
(
input
,
amax
,
N
,
N
,
noop_ptr
);
break
;
break
;
}
}
}
}
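The new noop_ptr argument lets work captured in a CUDA graph be disabled at replay time by writing 1.0f to the flag, with no change to the graph's topology. A minimal sketch of the device-side guard; guarded_kernel is an illustrative name:

  // A graph replay can disable this kernel's work by setting *noop to 1.0f.
  __global__ void guarded_kernel(const float *noop, float *out, int n) {
    // Every block reads the flag and returns early, so the launch stays
    // inside the captured graph but does (almost) no work when disabled.
    if (noop != nullptr && noop[0] == 1.0f) return;
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] *= 2.f;
  }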
...
@@ -188,8 +194,10 @@ void launch_channel_colwise_amax_kernel(const InputType *input, float *amax, con
 }  // namespace
 }  // namespace transformer_engine

-void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
-  NVTE_API_CALL(nvte_compute_amax);
+namespace {
+void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
+                       const NVTEQuantizationConfig config_) {
   using namespace transformer_engine;

   // Check input tensor
...
@@ -224,12 +232,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
                 to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);

+  float *noop_ptr = nullptr;
+  if (config_ != nullptr) {
+    const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
+    // extract noop tensor from quant_config_cpp if it's not null
+    const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
+    noop_ptr = reinterpret_cast<float *>(
+        (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
+  }
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                                reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
-                               stream););  // NOLINT(*)
+                               noop_ptr, stream););  // NOLINT(*)
 }
+}  // anonymous namespace
+
+void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_compute_amax);
+  compute_amax_impl(input_, output_, stream, nullptr);
+}
+
+void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
+                                   const NVTEQuantizationConfig config_, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_compute_amax_with_config);
+  compute_amax_impl(input_, output_, stream, config_);
+}
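The split into compute_amax_impl plus two thin entry points keeps the original nvte_compute_amax C ABI intact while adding a config-taking variant. A minimal sketch of the same wrapper pattern with illustrative names (Config, do_work*):

  struct Config { const float *noop; };

  namespace {
  void do_work_impl(float *buf, int n, const Config *cfg) {
    // cfg == nullptr preserves the legacy behavior exactly.
    (void)buf; (void)n; (void)cfg;
  }
  }  // namespace

  extern "C" void do_work(float *buf, int n) { do_work_impl(buf, n, nullptr); }
  extern "C" void do_work_with_config(float *buf, int n, const Config *cfg) {
    do_work_impl(buf, n, cfg);
  }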
 void nvte_compute_channel_colwise_amax(const NVTETensor input_, const NVTETensor output_,
                                        const NVTETensor fp8_scale_, cudaStream_t stream) {
@@ -271,7 +302,11 @@ namespace {
...
@@ -271,7 +302,11 @@ namespace {
__global__
void
compute_scale_from_amax_kernel
(
const
float
*
amax_ptr
,
float
*
scale_ptr
,
__global__
void
compute_scale_from_amax_kernel
(
const
float
*
amax_ptr
,
float
*
scale_ptr
,
const
float
max_fp8
,
const
bool
force_pow_2_scales
,
const
float
max_fp8
,
const
bool
force_pow_2_scales
,
const
float
epsilon
)
{
const
float
epsilon
,
const
float
*
noop_ptr
)
{
if
(
noop_ptr
!=
nullptr
&&
noop_ptr
[
0
]
==
1.0
f
)
{
return
;
}
*
scale_ptr
=
compute_scale_from_amax
(
*
amax_ptr
,
max_fp8
,
force_pow_2_scales
,
epsilon
,
*
scale_ptr
=
compute_scale_from_amax
(
*
amax_ptr
,
max_fp8
,
force_pow_2_scales
,
epsilon
,
std
::
numeric_limits
<
float
>::
max
());
std
::
numeric_limits
<
float
>::
max
());
}
}
...
@@ -317,10 +352,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
   TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(output.data.dtype, DType,
                                       max_fp8 = Quantized_Limits<DType>::max_norm;);

+  // noop tensor for cuda graph
+  float *noop_ptr = nullptr;
+  if (config_ != nullptr) {
+    const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
+    // extract noop tensor from quant_config_cpp if it's not null
+    const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
+    noop_ptr = reinterpret_cast<float *>(
+        (noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
+  }
+
   // Update scale
   compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
       reinterpret_cast<const float *>(output.amax.dptr),
       reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
-      config.amax_epsilon);
+      config.amax_epsilon, noop_ptr);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
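For reference, the scale derived from an amax in current scaling is essentially max_fp8 / max(amax, epsilon), optionally forced down to a power of two. A hedged sketch of that rule; TE's real helper, compute_scale_from_amax, also clamps against a maximum scale value:

  #include <cmath>

  float scale_from_amax(float amax, float max_fp8, bool force_pow_2, float epsilon) {
    if (amax < epsilon) amax = epsilon;    // avoid dividing by ~0
    float scale = max_fp8 / amax;          // map amax onto the FP8 max value
    if (force_pow_2) scale = std::exp2(std::floor(std::log2(scale)));
    return scale;
  }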
transformer_engine/common/recipe/fp8_block_scaling.cu
View file @ 27ddce40
...
@@ -373,6 +373,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_
       break;
     }
   })
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
...
@@ -420,6 +421,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s
       break;
     }
   })))
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 }  // namespace fp8_block_scaling_recipe
...
transformer_engine/common/swizzle/swizzle.cu
View file @ 27ddce40
...
@@ -410,22 +410,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
       break;
 #else
     case 4:
-      cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-                           cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(
+          cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(
               input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
       break;
     case 2:
-      cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-                           cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(
+          cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(
               input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
       break;
     case 1:
-      cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-                           cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(
+          cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(
              input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K);
...
@@ -435,6 +438,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
       NVTE_ERROR("Not valid vec_load_size.");
       break;
   }
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 if (input->has_columnwise_data()) {
   int vec_load_size = (num_tiles_m - 1) % 4 + 1;
...
@@ -472,24 +476,27 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
       break;
 #else
     case 4:
-      cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-                           cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(
+          cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                          output->columnwise_scale_inv.dptr, m,
                                                          k, original_M, original_K);
       break;
     case 2:
-      cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-                           cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(
+          cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                          output->columnwise_scale_inv.dptr, m,
                                                          k, original_M, original_K);
       break;
     case 1:
-      cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-                           cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(
+          cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                          output->columnwise_scale_inv.dptr, m,
...
@@ -500,6 +507,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
       NVTE_ERROR("Not valid vec_load_size.");
       break;
   }
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }

 // 2D block scaling
...
@@ -563,23 +571,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
       break;
 #else
     case 4:
-      cudaFuncSetAttribute(
-          multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
       break;
     case 2:
-      cudaFuncSetAttribute(
-          multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
       break;
     case 1:
-      cudaFuncSetAttribute(
-          multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
       break;
...
@@ -614,23 +622,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
       break;
 #else
     case 4:
-      cudaFuncSetAttribute(
-          multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
           <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
       break;
     case 2:
-      cudaFuncSetAttribute(
-          multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
       break;
     case 1:
-      cudaFuncSetAttribute(
-          multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
-          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
+      NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+          multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
+          cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
       multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
          <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
       break;
...
transformer_engine/common/transformer_engine.cpp
View file @ 27ddce40
...
@@ -560,11 +560,11 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
   // Zero out tensor data if allocated
   if (t.data.dptr != nullptr) {
     const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
-    cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream);
+    NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
   }
   // Set amax to 0 if allocated
   if (t.amax.dptr != nullptr) {
-    cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream);
+    NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
   }
 }
...