OpenDAS / TransformerEngine

Commit 87e3e56e, authored Aug 27, 2025 by yuguo
Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine

Parents: 2f11bd2e, 734bcedd
Changes: 217
Showing 20 changed files with 764 additions and 397 deletions (+764, -397)
Files changed:

  transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h    +19   -1
  transformer_engine/common/include/transformer_engine/gemm.h                 +30   -0
  transformer_engine/common/include/transformer_engine/multi_tensor.h         +15   -34
  transformer_engine/common/include/transformer_engine/swizzle.h              +14   -0
  transformer_engine/common/include/transformer_engine/transpose.h            +8    -0
  transformer_engine/common/libtransformer_engine.version                     +1    -0
  transformer_engine/common/multi_tensor/adam.cu                              +28   -32
  transformer_engine/common/multi_tensor/compute_scale.cu                     +8    -6
  transformer_engine/common/multi_tensor/l2norm.cu                            +8    -13
  transformer_engine/common/multi_tensor/multi_tensor_apply.cuh               +1    -50
  transformer_engine/common/multi_tensor/scale.cu                             +4    -5
  transformer_engine/common/multi_tensor/sgd.cu                               +7    -8
  transformer_engine/common/normalization/layernorm/ln_api.cpp                +4    -0
  transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh        +26   -20
  transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp             +4    -0
  transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh     +26   -20
  transformer_engine/common/swizzle/swizzle.cu                                +451  -180
  transformer_engine/common/transformer_engine.cpp                            +2    -1
  transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu  +57   -12
  transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu  +51   -15
File: transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (+19, -1)

@@ -36,7 +36,8 @@ enum class CommOverlapAlgo {
   SPLIT_PIPELINED_RS_P2P = 4,
   ATOMIC_GEMM_RS = 5,
   ATOMIC_GEMM_AG_P2P = 6,
-  ATOMIC_GEMM_RS_P2P = 7
+  ATOMIC_GEMM_RS_P2P = 7,
+  EXTERNAL_BULK_OVERLAP_AG = 8,
 };

 class CommOverlapCore {

@@ -133,6 +134,11 @@ class CommOverlapCore {
                             cudaStream_t stream_main) {
     NVTE_ERROR("Operation is not implemented.");
   }
+
+  virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                        cudaStream_t stream_main) {
+    NVTE_ERROR("Operation is not implemented.");
+  }
 };  // CommOverlapCore

 class CommOverlapBase : public CommOverlapCore {

@@ -200,6 +206,9 @@ class CommOverlapBase : public CommOverlapCore {
                     TensorWrapper &workspace, bool grad, bool accumulate,
                     bool use_split_accumulator, TensorWrapper &rs_output,
                     cudaStream_t stream_main) override;
+
+  void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                cudaStream_t stream_main) override;
 };  // CommOverlapBase

 class CommOverlapP2PBase : public CommOverlapCore {

@@ -281,6 +290,15 @@ class CommOverlapP2PBase : public CommOverlapCore {
                     TensorWrapper &workspace, bool grad, bool accumulate,
                     bool use_split_accumulator, TensorWrapper &rs_output,
                     cudaStream_t stream_main) override;
+
+  /*
+  ** This function overlaps the AG for the current communicator object with the GEMM for the
+  ** overlap_gemm object. The gemm for overlap_gemm is assumed to have been previously started.
+  */
+  void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
+                                cudaStream_t stream_main) override {
+    NVTE_ERROR("Operation not supported.");
+  }
 };  // CommOverlapP2PBase

 }  // namespace transformer_engine
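The new EXTERNAL_BULK_OVERLAP_AG algorithm and the bulk_overlap_external_ag() hook are only declared in this header. Below is a minimal host-side sketch of how a caller might drive the hook; the `overlap` object is assumed to be an already constructed CommOverlapBase communicator (construction omitted), the helper function name is hypothetical, and only the method name, its three stream parameters, and the precondition that the overlapped GEMM was launched beforehand come from the header above.

    // Sketch only: drive the new external bulk all-gather overlap from caller-owned streams.
    #include <cuda_runtime.h>
    #include <transformer_engine/comm_gemm_overlap.h>

    void run_external_bulk_ag(transformer_engine::CommOverlapBase &overlap) {
      cudaStream_t send_stream, recv_stream, main_stream;
      cudaStreamCreate(&send_stream);
      cudaStreamCreate(&recv_stream);
      cudaStreamCreate(&main_stream);

      // Per the comment added to CommOverlapP2PBase, the GEMM that this all-gather overlaps
      // with is assumed to have been started beforehand on the other communicator object.
      overlap.bulk_overlap_external_ag(send_stream, recv_stream, main_stream);

      cudaStreamSynchronize(main_stream);
      cudaStreamDestroy(send_stream);
      cudaStreamDestroy(recv_stream);
      cudaStreamDestroy(main_stream);
    }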
File: transformer_engine/common/include/transformer_engine/gemm.h (+30, -0)

@@ -44,6 +44,36 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
                       NVTETensor workspace, bool accumulate, bool use_split_accumulator,
                       int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0,
                       bool nvte_use_rocblas = 0, int compute_stream_offset = 0);

+/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
+ *         allowing for using a scaling factor for the GEMM result and the accumulation input
+ *
+ * Computes:
+ *  - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
+ *  - `D = alpha*AB + bias` if `pre_gelu_out` is empty and `bias` is not empty
+ *  - `D = GELU(alpha*AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors
+ *
+ *  \param[in]     A                      The A matrix.
+ *  \param[in]     B                      The B matrix.
+ *  \param[in,out] D                      Output matrix.
+ *  \param[in]     bias                   Bias tensor.
+ *  \param[in,out] pre_gelu_out           Output matrix before GELU activation.
+ *  \param[in]     transa                 Whether A matrix is transposed.
+ *  \param[in]     transb                 Whether B matrix is transposed.
+ *  \param[in]     grad                   Whether this operation is part of the gradient computation.
+ *  \param[out]    workspace              Workspace tensor.
+ *  \param[in]     alpha                  Scaling factor applied to the result of the GEMM
+ *  \param[in]     beta                   Scaling factor applied to original value of D when
+ *                                        accumulating into it. beta=0 means no accumulation.
+ *  \param[in]     use_split_accumulator  Whether to use split accumulator in the FP8 GEMM.
+ *  \param[in]     math_sm_count          Number of GPU SMs to use (default=0: use cuBLAS heuristics)
+ *  \param[in]     stream                 CUDA stream used for the operation.
+ */
+void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D,
+                             const NVTETensor bias, NVTETensor pre_gelu_out, bool transa,
+                             bool transb, bool grad, NVTETensor workspace, float alpha, float beta,
+                             bool use_split_accumulator, int math_sm_count, cudaStream_t stream);
+
 /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters.
  *
  * \warning Cublas atomic gemm uses a beta API and is not tested for all use cases.
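A minimal sketch of a call against the new entry point follows. With empty `bias` and `pre_gelu_out` it computes D = alpha*A*B, and with beta != 0 it additionally blends in beta times the existing contents of D. The NVTETensor handles are assumed to have been created and filled elsewhere; the null handles used for the unfused inputs are a stand-in for the "empty tensors" mentioned in the comment above.

    // Sketch only: D = alpha * A * B + beta * D via the new scaled GEMM entry point.
    #include <cuda_runtime.h>
    #include <transformer_engine/gemm.h>

    void scaled_gemm(NVTETensor A, NVTETensor B, NVTETensor D, NVTETensor workspace,
                     cudaStream_t stream) {
      NVTETensor bias = nullptr;          // stand-in for an empty tensor: no bias fusion
      NVTETensor pre_gelu_out = nullptr;  // stand-in for an empty tensor: no GELU fusion
      const float alpha = 1.0f;
      const float beta = 0.5f;            // beta != 0 accumulates onto the previous contents of D

      nvte_cublas_gemm_scaled(A, B, D, bias, pre_gelu_out,
                              /*transa=*/true, /*transb=*/false, /*grad=*/false,
                              workspace, alpha, beta,
                              /*use_split_accumulator=*/false,
                              /*math_sm_count=*/0, stream);
    }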
File: transformer_engine/common/include/transformer_engine/multi_tensor.h (+15, -34)

@@ -20,7 +20,6 @@ extern "C" {
 /*! \brief Computes L2 norm for a list of tensors.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -33,22 +32,19 @@ extern "C" {
  *  \param[out]    ret_per_tensor          L2 norm for each tensor.
  *  \param[in]     per_tensor              Whether to calculate per tensor or cumulative norm.
  *  \param[in]     max_chunks_per_tensor   Maximum number of chunks in any input tensor.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                    const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                    NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
                                    NVTETensor ret_per_tensor, int per_tensor,
-                                   int max_chunks_per_tensor, const int device_id,
-                                   cudaStream_t stream);
+                                   int max_chunks_per_tensor, cudaStream_t stream);

 /*! \brief Computes L2 norm for a list of tensors after unscaling.
  *
  * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
  *  \param[in]     inv_scale               Scalar for the unscaling operation.
  *  \param[in]     per_tensor              Whether to calculate per tensor or cumulative norm.
  *  \param[in]     max_chunks_per_tensor   Maximum number of chunks in any input tensor.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,

@@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                            NVTETensor output_per_tensor, NVTETensor ret,
                                            NVTETensor ret_per_tensor, NVTETensor inv_scale,
                                            int per_tensor, int max_chunks_per_tensor,
-                                           const int device_id, cudaStream_t stream);
+                                           cudaStream_t stream);

 /*! \brief Compute and apply gradient update to parameters for Adam optimizer.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
  *  \param[in]     mode                    Whether to use AdamW (L2 penalty applied to params).
  *  \param[in]     bias_correction         Whether to apply correction factor for moment estimates.
  *  \param[in]     weight_decay            L2 penalty for weight decay.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,

@@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
                                  const float lr, const float beta1, const float beta2,
                                  const float epsilon, const int step, const int mode,
                                  const int bias_correction, const float weight_decay,
-                                 const int device_id, cudaStream_t stream);
+                                 cudaStream_t stream);

 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         where the master parameters only store the remainder bits.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
  *  \param[in]     mode                    Whether to use AdamW (L2 penalty applied to params).
  *  \param[in]     bias_correction         Whether to apply correction factor for moment estimates.
  *  \param[in]     weight_decay            L2 penalty for weight decay.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_adam_param_remainder_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
     const float epsilon, const int step, const int mode, const int bias_correction,
-    const float weight_decay, const int device_id, cudaStream_t stream);
+    const float weight_decay, cudaStream_t stream);

 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         when model parameters are in Float8 precision.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda(
  *  \param[in]     bias_correction         Whether to apply correction factor for moment estimates.
  *  \param[in]     weight_decay            L2 penalty for weight decay.
  *  \param[in]     fp8_dtype               FP8 data type for model parameters.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,

@@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                      const float beta1, const float beta2, const float epsilon,
                                      const int step, const int mode, const int bias_correction,
                                      const float weight_decay, const NVTEDType fp8_dtype,
-                                     const int device_id, cudaStream_t stream);
+                                     cudaStream_t stream);

 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         with CUDA graph support and LR scheduling.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
  *  \param[in]     bias_correction         Whether to apply correction factor for moment estimates.
  *  \param[in]     weight_decay            L2 penalty for weight decay.
  *  \param[in]     inv_scale               Scalar for the unscaling operation.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_adam_capturable_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);

 /*! \brief Compute and apply gradient update to parameters for Adam optimizer
  *         with CUDA graph support, LR scheduling, and FP32 master weights.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda(
  *  \param[in]     bias_correction         Whether to apply correction factor for moment estimates.
  *  \param[in]     weight_decay            L2 penalty for weight decay.
  *  \param[in]     inv_scale               Scalar for the unscaling operation.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_adam_capturable_master_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream);

 /*! \brief Compute and apply gradient update to parameters for SGD optimizer.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
  *  \param[in]     first_run               Whether momentum buffers have been initialized.
  *  \param[in]     wd_after_momentum       Whether to applied weight decay after momentum update.
  *  \param[in]     scale                   Scalar for the scaling operation.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                 float wd, float momentum, float dampening, float lr, int nesterov,
                                 int first_run, int wd_after_momentum, float scale,
-                                const int device_id, cudaStream_t stream);
+                                cudaStream_t stream);

 /*! \brief Check overflow and scale a list of tensors.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
  *  \param[in]     num_tensor_lists        Size (dim0) of tensor_lists.
  *  \param[in]     num_tensors_per_list    Size (dim1) of tensor_lists.
  *  \param[in]     scale                   Scalar for the scaling operation.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
 void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
-                                  float scale, const int device_id, cudaStream_t stream);
+                                  float scale, cudaStream_t stream);

 /*! \brief Check overflow and scale a list of tensors.
  *
  *  \warning   This API is **experimental** and subject to change.
- *  \warning   Argument device_id is deprecated and will be removed in a future release.
  *
  *  \param[in]     chunk_size              Number of tensor elements processed by a CUDA block.
  *  \param[in]     noop_flag               If this single element tensor has non-zero value, kernel will exit immediately.

@@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens
  *  \param[in]     max_fp8                 Maximum representible value in underlying FP8 format.
  *  \param[in]     force_pow_2_scales      Ensure scaling factors are a power of 2.
  *  \param[in]     epsilon                 Term added to the denominator for numerical stability.
- *  \param[in]     device_id               [DEPRECATED] CUDA device ID for this operation.
  *  \param[in]     stream                  CUDA stream used for this operation.
  */
-void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
-    int chunk_size, NVTETensor noop_flag,
-    NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list,
-    float max_fp8, int force_pow_2_scales, float epsilon, const int device_id,
-    cudaStream_t stream);
+void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
+    int chunk_size, NVTETensor noop_flag,
+    NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list,
+    float max_fp8, int force_pow_2_scales, float epsilon, cudaStream_t stream);

 #ifdef __cplusplus
 }  // extern "C"
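The change running through this whole header is the removal of the previously deprecated trailing device_id argument from every nvte_multi_tensor_* entry point; the kernels now run on whatever CUDA device is current. A minimal sketch of a call against the updated nvte_multi_tensor_scale_cuda signature is shown below. The two-list layout (inputs in list 0, outputs in list 1) and the chunk size are illustrative assumptions, not something this header prescribes.

    // Sketch only: scale a batch of tensors with the updated signature (no device_id).
    #include <cuda_runtime.h>
    #include <transformer_engine/multi_tensor.h>

    void scale_all(NVTETensor noop_flag, NVTETensor **tensor_lists, size_t num_tensors,
                   float scale, cudaStream_t stream) {
      const int chunk_size = 65536;  // elements handled per CUDA block, chosen by the caller
      nvte_multi_tensor_scale_cuda(chunk_size, noop_flag, tensor_lists,
                                   /*num_tensor_lists=*/2, /*num_tensors_per_list=*/num_tensors,
                                   scale, stream);
    }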
File: transformer_engine/common/include/transformer_engine/swizzle.h (+14, -0)

@@ -30,6 +30,20 @@ extern "C" {
  */
 void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream);

+/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM
+ *
+ *  \param[in]     inputs     Input tensors with non-swizzled scale_inv.
+ *  \param[in,out] outputs    Output tensors which hosts swizzled scale_inv.
+ *  \param[in]     stream     CUDA stream used for the operation.
+ *
+ * Requirements:
+ * - scale_inv is stored in row-major.
+ * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale.
+ * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension.
+ */
+void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor *inputs, NVTETensor *outputs,
+                                               const size_t num_tensors, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
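A short sketch of the batched variant: instead of looping over nvte_swizzle_scaling_factors per tensor, the caller hands over parallel arrays of input and output handles. The handles are assumed to already satisfy the layout requirements listed in the comment above.

    // Sketch only: swizzle scale_inv for a whole batch of quantized tensors in one call.
    #include <cuda_runtime.h>
    #include <transformer_engine/swizzle.h>
    #include <vector>

    void swizzle_batch(const std::vector<NVTETensor> &inputs, std::vector<NVTETensor> &outputs,
                       cudaStream_t stream) {
      nvte_multi_tensor_swizzle_scaling_factors(inputs.data(), outputs.data(), inputs.size(),
                                                stream);
    }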
File: transformer_engine/common/include/transformer_engine/transpose.h (+8, -0)

@@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
 void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input,
                                  NVTETensor output, cudaStream_t stream);

+/*! \brief Swap the first two tensor dimensions.
+ *
+ *  \param[in]     input     Input tensor of shape [M, N, ...].
+ *  \param[out]    output    Output tensor of shape [N, M, ...].
+ *  \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
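The new helper swaps only the two leading dimensions, so a [M, N, ...] tensor becomes [N, M, ...]. A trivial sketch follows; both handles are assumed to be pre-created with matching dtype and element count.

    // Sketch only: swap the leading two dimensions, e.g. [M, N, H] -> [N, M, H].
    #include <cuda_runtime.h>
    #include <transformer_engine/transpose.h>

    void swap_leading_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
      nvte_swap_first_dims(input, output, stream);
    }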
File: transformer_engine/common/libtransformer_engine.version (+1, -0)

@@ -8,6 +8,7 @@
     transformer_engine::cuda::stream_priority_range*;
     transformer_engine::cuda::current_device*;
     transformer_engine::cuda_driver::get_symbol*;
+    transformer_engine::cuda_driver::ensure_context_exists*;
     transformer_engine::ubuf_built_with_mpi*;
     *transformer_engine::rtc*;
     transformer_engine::nvte_cudnn_handle_init*;
File: transformer_engine/common/multi_tensor/adam.cu (+28, -32)

@@ -577,7 +577,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
                             std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                             const float beta1, const float beta2, const float epsilon,
                             const int step, const int mode, const int bias_correction,
-                            const float weight_decay, const int device_id, cudaStream_t stream) {
+                            const float weight_decay, cudaStream_t stream) {
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {

@@ -644,9 +644,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag,
                                           tensor_lists,
-                                          AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
+                                          AdamFunctor<p_in_type, g_in_type, float, int64_t>(),
                                           stream,
                                           beta1, beta2, bias_correction1, bias_correction2,
                                           epsilon, lr,
                                           (adamMode_t)mode, weight_decay);));
   } else {
     // g, p, m, v, p_master
     TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(

@@ -655,7 +655,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5>(
             (int64_t)chunk_size, noop_flag, tensor_lists,
-            AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
+            AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), stream,
             beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
             weight_decay);));
   }

@@ -667,7 +667,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
     TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                          AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
+                                          AdamFunctor<p_in_type, g_in_type, float, int32_t>(),
                                           stream, beta1, beta2, bias_correction1, bias_correction2,
                                           epsilon, lr, (adamMode_t)mode, weight_decay);));
   } else {

@@ -678,9 +678,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
                                           AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
-                                          device_id, stream, beta1, beta2, bias_correction1,
+                                          stream, beta1, beta2, bias_correction1,
                                           bias_correction2,
                                           epsilon, lr, (adamMode_t)mode,
                                           weight_decay);));
     }
   }
   NVTE_CHECK_CUDA(cudaGetLastError());

@@ -691,7 +690,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
                                             const float lr, const float beta1, const float beta2,
                                             const float epsilon, const int step, const int mode,
                                             const int bias_correction, const float weight_decay,
-                                            const int device_id, cudaStream_t stream) {
+                                            cudaStream_t stream) {
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {

@@ -733,7 +732,7 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       g_in_type_te, g_in_type,
       multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
-                                        AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
+                                        AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(),
                                         stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
                                         (adamMode_t)mode, weight_decay););
   NVTE_CHECK_CUDA(cudaGetLastError());

@@ -744,7 +743,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
                                 const float beta1, const float beta2, const float epsilon,
                                 const int step, const int mode, const int bias_correction,
                                 const float weight_decay, const DType fp8_dtype,
-                                const int device_id, cudaStream_t stream) {
+                                cudaStream_t stream) {
   // Handle bias correction mode
   float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
   if (bias_correction == 1) {

@@ -814,7 +813,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5, true>(
             (int64_t)chunk_size, noop_flag, tensor_lists,
-            AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
+            AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), stream, beta1,
             beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
             weight_decay);));
   } else {

@@ -824,9 +823,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
         g_in_type_te, g_in_type,
         multi_tensor_apply<BLOCK_SIZE, 5, true>(chunk_size, noop_flag, tensor_lists,
                                                 AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
-                                                device_id, stream, beta1, beta2, bias_correction1,
+                                                stream, beta1, beta2, bias_correction1,
                                                 bias_correction2,
                                                 epsilon, lr, (adamMode_t)mode,
                                                 weight_decay);));
   }
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

@@ -836,7 +834,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
                                        const float beta1, const float beta2, const float epsilon,
                                        Tensor step, const int mode, const int bias_correction,
                                        const float weight_decay, Tensor inv_scale,
-                                       const int device_id, cudaStream_t stream) {
+                                       cudaStream_t stream) {
   // Check tensor list sizes
   // 4 tensor lists: g, p, m, v
   const size_t num_tensor_lists = tensor_lists.size();

@@ -868,7 +866,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                        AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
+                                        AdamCapturableFunctor<dtype, float>(), stream, beta1, beta2,
                                         reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
                                         reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
                                         reinterpret_cast<float *>(inv_scale.data.dptr));)

@@ -881,8 +879,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
                                               Tensor lr, const float beta1, const float beta2,
                                               const float epsilon, Tensor step, const int mode,
                                               const int bias_correction, const float weight_decay,
-                                              Tensor inv_scale, const int device_id,
-                                              cudaStream_t stream) {
+                                              Tensor inv_scale, cudaStream_t stream) {
   // Check tensor list sizes
   // 4 tensor lists: g, p, m, v, p_master
   const size_t num_tensor_lists = tensor_lists.size();

@@ -917,7 +914,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
-                                        AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
+                                        AdamCapturableMasterFunctor<dtype, float>(), stream, beta1,
                                         beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
                                         epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
                                         weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)

@@ -933,28 +930,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso
                                  const float lr, const float beta1, const float beta2,
                                  const float epsilon, const int step, const int mode,
                                  const int bias_correction, const float weight_decay,
-                                 const int device_id, cudaStream_t stream) {
+                                 cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
   using namespace transformer_engine;
   multi_tensor_adam::multi_tensor_adam_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
-      epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
+      epsilon, step, mode, bias_correction, weight_decay, stream);
 }

 void nvte_multi_tensor_adam_param_remainder_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
     const float epsilon, const int step, const int mode, const int bias_correction,
-    const float weight_decay, const int device_id, cudaStream_t stream) {
+    const float weight_decay, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
   using namespace transformer_engine;
   multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
-      epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
+      epsilon, step, mode, bias_correction, weight_decay, stream);
 }

 void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,

@@ -963,22 +960,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
                                      const float beta1, const float beta2, const float epsilon,
                                      const int step, const int mode, const int bias_correction,
                                      const float weight_decay, const NVTEDType fp8_dtype,
-                                     const int device_id, cudaStream_t stream) {
+                                     cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
   using namespace transformer_engine;
   multi_tensor_adam::multi_tensor_adam_fp8_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
-      epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
-      stream);
+      epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), stream);
 }

 void nvte_multi_tensor_adam_capturable_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
   using namespace transformer_engine;

@@ -986,14 +982,14 @@ void nvte_multi_tensor_adam_capturable_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
-      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
+      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
 }

 void nvte_multi_tensor_adam_capturable_master_cuda(
     int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
     const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
     const float epsilon, NVTETensor step, const int mode, const int bias_correction,
-    const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
+    const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
   using namespace transformer_engine;

@@ -1001,5 +997,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode,
-      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream);
+      bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream);
 }
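On the caller side, the only migration this file implies is dropping the trailing device_id argument. A hedged sketch of one optimizer step through the updated nvte_multi_tensor_adam_cuda wrapper is shown below; the four-list layout (g, p, m, v) comes from the comments in this file, while the chunk size, hyperparameters, and the mode value are illustrative assumptions.

    // Sketch only: one Adam step with the updated wrapper (no device_id argument).
    // tensor_lists is assumed to be a caller-built [4][num_tensors] array: g, p, m, v.
    #include <cuda_runtime.h>
    #include <transformer_engine/multi_tensor.h>

    void adam_step(NVTETensor noop_flag, NVTETensor **tensor_lists, size_t num_tensors,
                   float lr, int step, cudaStream_t stream) {
      nvte_multi_tensor_adam_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists,
                                  /*num_tensor_lists=*/4, /*num_tensors_per_list=*/num_tensors,
                                  lr, /*beta1=*/0.9f, /*beta2=*/0.999f, /*epsilon=*/1e-8f,
                                  step, /*mode=*/0 /* adamMode_t value, illustrative */,
                                  /*bias_correction=*/1, /*weight_decay=*/0.0f, stream);
    }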
transformer_engine/common/multi_tensor/compute_scale.cu
View file @ 87e3e56e
...
@@ -61,7 +61,7 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
                                                     float epsilon, const int device_id,
                                                     cudaStream_t stream) {
   multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                    ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
+                                    ComputeScaleAndScaleInvFunctor(), stream, max_fp8,
                                     force_pow_2_scales, epsilon);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
...
@@ -69,15 +69,17 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
 }  // namespace multi_tensor_compute_scale
 }  // namespace transformer_engine
 
 void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
     int chunk_size, NVTETensor noop_flag,
-    NVTETensor **tensor_lists, const size_t num_tensor_lists,
-    const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
-    const int device_id, cudaStream_t stream) {
+    NVTETensor **tensor_lists,
+    const size_t num_tensor_lists,
+    const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
+    cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
   using namespace transformer_engine;
 
   multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
-      force_pow_2_scales, epsilon, device_id, stream);
+      force_pow_2_scales, epsilon, stream);
 }
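Note: the body of ComputeScaleAndScaleInvFunctor is not part of this hunk. As a reader aid only, here is a minimal host-side sketch of the usual recipe such a kernel implements (max_fp8 over amax with an epsilon floor, optionally snapped down to a power of two). The helper name and the exact clamping rules are assumptions, not the functor from compute_scale.cu:

    #include <algorithm>
    #include <cmath>

    // Hypothetical helper illustrating the scale / scale_inv relationship only.
    static void compute_scale_and_scale_inv(float amax, float max_fp8, bool force_pow_2_scales,
                                            float epsilon, float* scale, float* scale_inv) {
      const float clamped_amax = std::max(amax, epsilon);   // avoid dividing by ~0
      float s = max_fp8 / clamped_amax;                     // map amax onto the FP8 representable range
      if (force_pow_2_scales) {
        s = std::exp2f(std::floor(std::log2f(s)));          // round down to a power of two
      }
      *scale = s;
      *scale_inv = 1.0f / s;                                // inverse kept alongside for dequantization
    }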
transformer_engine/common/multi_tensor/l2norm.cu
View file @ 87e3e56e
...
@@ -401,12 +401,11 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
 void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
                               std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
                               Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
-                              bool per_tensor, int max_chunks_per_tensor, const int device_id,
-                              cudaStream_t stream) {
+                              bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 1>(
-          chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
+          chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(),
           stream, reinterpret_cast<float *>(output.data.dptr),
           per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
           max_chunks_per_tensor);)
...
@@ -416,7 +415,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  const OptionalCUDAGuard device_guard(device_id);
   cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
       reinterpret_cast<float *>(output.data.dptr),
       per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
...
@@ -429,12 +427,11 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
                                       std::vector<std::vector<Tensor *>> tensor_lists,
                                       Tensor output, Tensor output_per_tensor, Tensor ret,
                                       Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
-                                      int max_chunks_per_tensor, const int device_id,
-                                      cudaStream_t stream) {
+                                      int max_chunks_per_tensor, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), dtype,
       multi_tensor_apply<BLOCK_SIZE, 1>(
-          chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
+          chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(),
           stream, reinterpret_cast<float *>(inv_scale.data.dptr),
           reinterpret_cast<float *>(output.data.dptr),
           per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
...
@@ -445,7 +442,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
-  const OptionalCUDAGuard device_guard(device_id);
   cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
       reinterpret_cast<float *>(output.data.dptr),
       per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
...
@@ -461,8 +457,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
                                    const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                    NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
                                    NVTETensor ret_per_tensor, int per_tensor,
-                                   int max_chunks_per_tensor, const int device_id,
-                                   cudaStream_t stream) {
+                                   int max_chunks_per_tensor, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
   using namespace transformer_engine;
...
@@ -471,7 +466,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
       *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor,
-      max_chunks_per_tensor, device_id, stream);
+      max_chunks_per_tensor, stream);
 }
 
 void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
...
@@ -480,7 +475,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
                                            NVTETensor output_per_tensor, NVTETensor ret,
                                            NVTETensor ret_per_tensor, NVTETensor inv_scale,
                                            int per_tensor, int max_chunks_per_tensor,
-                                           const int device_id, cudaStream_t stream) {
+                                           cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
   using namespace transformer_engine;
...
@@ -489,5 +484,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
       *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor),
       *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor),
-      *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
+      *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream);
 }
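The cleanup launch in these functions is the second stage of a two-pass reduction: the multi-tensor kernel writes one partial result per (tensor, chunk), and a small follow-up kernel folds those partials into the final norms. A minimal single-tensor sketch of that pattern, with hypothetical kernel names (not the L2NormFunctor/cleanup kernels in l2norm.cu):

    #include <cuda_runtime.h>
    #include <math.h>

    // Assumes blockDim.x is a power of two and at most 512.
    __device__ float block_reduce_sum(float v) {
      __shared__ float smem[512];
      smem[threadIdx.x] = v;
      __syncthreads();
      for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
        __syncthreads();
      }
      return smem[0];
    }

    // Stage 1: each block accumulates the squared elements of its slice.
    __global__ void partial_sq_sum(const float* x, int n, float* partial) {
      float acc = 0.f;
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
        acc += x[i] * x[i];
      const float total = block_reduce_sum(acc);
      if (threadIdx.x == 0) partial[blockIdx.x] = total;
    }

    // Stage 2 ("cleanup"): a single block folds the per-chunk partials into the L2 norm.
    __global__ void cleanup_sq_sum(const float* partial, int num_chunks, float* out) {
      float acc = 0.f;
      for (int i = threadIdx.x; i < num_chunks; i += blockDim.x) acc += partial[i];
      const float total = block_reduce_sum(acc);
      if (threadIdx.x == 0) *out = sqrtf(total);
    }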
transformer_engine/common/multi_tensor/multi_tensor_apply.cuh
View file @ 87e3e56e
...
@@ -14,53 +14,6 @@
 // This header is the one-stop shop for all your multi-tensor apply needs.
 
-// Change device if needed.
-class OptionalCUDAGuard {
- public:
-  explicit OptionalCUDAGuard(int new_device) {
-    if (new_device < 0) return;
-    int current_device;
-    NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
-    if (new_device != current_device) {
-      NVTE_CHECK_CUDA(cudaSetDevice(new_device));
-      device_changed_ = true;
-      prev_device_ = current_device;
-    }
-  }
-
-  OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
-  OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
-  OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
-      : prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
-    other.device_changed_ = false;
-  }
-  OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
-    if (this != &other) {
-      if (device_changed_) {
-        cudaSetDevice(prev_device_);
-      }
-      prev_device_ = other.prev_device_;
-      device_changed_ = other.device_changed_;
-      other.device_changed_ = false;
-    }
-    return *this;
-  }
-
-  ~OptionalCUDAGuard() {
-    if (device_changed_) {
-      cudaSetDevice(prev_device_);
-    }
-  }
-
- private:
-  int prev_device_;
-  bool device_changed_ = false;
-};
-
 // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
 constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
 constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
...
@@ -94,7 +47,7 @@ template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typen
 void multi_tensor_apply(int64_t chunk_size,
                         const transformer_engine::Tensor &noop_flag,
                         std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
-                        T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
+                        T callable, cudaStream_t stream, ArgTypes... args) {
   const size_t num_tensor_lists = tensor_lists.size();
   const size_t num_tensors_per_list = tensor_lists[0].size();
...
@@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t chunk_size,
   TensorListMetadata<depth, USE_FP8> tl;
 
-  const OptionalCUDAGuard device_guard(device_id);
-
   tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
   int loc_tensor_info = 0;
...
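With OptionalCUDAGuard gone from this header, a caller that juggles several GPUs is expected to pin the device itself before launching the multi-tensor kernels. A minimal host-side sketch of that responsibility; the function name and the surrounding launch are assumptions, not code from this commit:

    #include <cuda_runtime.h>

    // Hypothetical caller-side helper: mirror what the removed RAII guard did,
    // i.e. switch to the target device only when needed and restore it afterwards.
    void run_on_device(int target_device, void (*enqueue_work)(cudaStream_t), cudaStream_t stream) {
      int prev_device = 0;
      cudaGetDevice(&prev_device);                        // remember the current device
      const bool switched = (target_device >= 0 && target_device != prev_device);
      if (switched) cudaSetDevice(target_device);         // switch for the launches below
      enqueue_work(stream);                               // e.g. the multi-tensor launches
      if (switched) cudaSetDevice(prev_device);           // restore the previous device
    }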
transformer_engine/common/multi_tensor/scale.cu
View file @ 87e3e56e
...
@@ -104,13 +104,13 @@ struct ScaleFunctor {
 void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
                              std::vector<std::vector<Tensor *>> tensor_lists, float scale,
-                             const int device_id, cudaStream_t stream) {
+                             cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
       tensor_lists[0][0]->dtype(), p_in_type,
       TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
           tensor_lists[1][0]->dtype(), g_in_type,
           multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
-                                            ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
+                                            ScaleFunctor<p_in_type, g_in_type>(), stream, scale);))
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
...
@@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
 void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
                                   const size_t num_tensor_lists, const size_t num_tensors_per_list,
-                                  float scale, const int device_id, cudaStream_t stream) {
+                                  float scale, cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
   using namespace transformer_engine;
 
   multi_tensor_scale::multi_tensor_scale_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
-      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
-      stream);
+      convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream);
 }
transformer_engine/common/multi_tensor/sgd.cu
View file @ 87e3e56e
...
@@ -127,8 +127,7 @@ struct SGDFunctor {
 void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
                            std::vector<std::vector<Tensor *>> tensor_lists, float wd, float momentum,
                            float dampening, float lr, bool nesterov, bool first_run,
-                           bool wd_after_momentum, float scale, const int device_id,
-                           cudaStream_t stream) {
+                           bool wd_after_momentum, float scale, cudaStream_t stream) {
   const size_t num_tensor_lists = tensor_lists.size();
   const size_t num_tensors_per_list = tensor_lists[0].size();
...
@@ -154,28 +153,28 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
   // Case 1. fp16, fp16, fp16, No
   if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
     multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
+                                      SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening,
                                       lr, nesterov, first_run, wd_after_momentum, scale);
   }
   // Case 2. fp32, fp32, fp32, No
   else if (grad_type == DType::kFloat32 &&  // NOLINT(*)
            weight_type == DType::kFloat32 && num_tensor_lists == 3) {
     multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
+                                      SGDFunctor<3, float, float>(), stream, wd, momentum, dampening,
                                       lr, nesterov, first_run, wd_after_momentum, scale);
   }
   // Case 3. fp16, fp32, fp32, Yes
   else if (grad_type == DType::kFloat16 &&  // NOLINT(*)
            weight_type == DType::kFloat32 && num_tensor_lists == 4) {
     multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
+                                      SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening,
                                       lr, nesterov, first_run, wd_after_momentum, scale);
   }
   // Case 4. fp32, fp32, fp32, Yes
   else if (grad_type == DType::kFloat32 &&  // NOLINT(*)
            weight_type == DType::kFloat32 && num_tensor_lists == 4) {
     multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
-                                      SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
+                                      SGDFunctor<4, float, float>(), stream, wd, momentum, dampening,
                                       lr, nesterov, first_run, wd_after_momentum, scale);
   } else {
     NVTE_ERROR("Unsupported combination of weight and gradient types.");
...
@@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor
                                 const size_t num_tensor_lists, const size_t num_tensors_per_list,
                                 float wd, float momentum, float dampening, float lr, int nesterov,
                                 int first_run, int wd_after_momentum, float scale,
-                                const int device_id, cudaStream_t stream) {
+                                cudaStream_t stream) {
   NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
   using namespace transformer_engine;
 
   multi_tensor_sgd::multi_tensor_sgd_cuda(
       chunk_size, *convertNVTETensorCheck(noop_flag),
       convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
-      dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
+      dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream);
 }
transformer_engine/common/normalization/layernorm/ln_api.cpp
View file @ 87e3e56e
...
@@ -72,6 +72,10 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
   bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
 #endif
 
+  if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
+    cudnn_backend = false;
+    // cuDNN does not currently support amax output for non quantized output
+  }
+
   bool gamma_in_weight_dtype = false;
   if (cudnn_backend) {
     // TODO: add check for GPU ARCH
...
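Read together with the preceding cudnn_backend assignment, the added guard means the cuDNN path is only taken when the output is FP8 or when no amax was requested. A compact restatement of that decision as a standalone predicate; the helper name and boolean inputs are illustrative, not code from ln_api.cpp:

    // Sketch of the backend decision after this change (names are hypothetical).
    bool pick_cudnn_backend(bool use_cudnn_norm_fwd, bool is_mxfp_scaling,
                            bool output_is_fp8, bool amax_requested) {
      bool cudnn_backend = use_cudnn_norm_fwd || is_mxfp_scaling;
      // cuDNN cannot yet produce an amax for a non-quantized output, so fall back.
      if (!output_is_fp8 && amax_requested) {
        cudnn_backend = false;
      }
      return cudnn_backend;
    }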
transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh
View file @ 87e3e56e
...
@@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
     scale = *reinterpret_cast<compute_t *>(params.scale);
   }
   compute_t amax = 0;
+  const bool requires_amax = params.amax != nullptr;
 
   for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
     Ivec x[LDGS];
...
@@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
         compute_t b_ij = beta[it].data.elt[jt];
         compute_t temp_output = g_ij * y_ij + b_ij;
 
-        if (params.fp8_out) {
+        if (requires_amax) {
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(temp_output));
+        }
+        if (params.fp8_out) {
           temp_output = temp_output * scale;
         }
...
@@ -132,9 +135,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
       idx += VEC_COLS_PER_LDG;
     }
   }
 
-  if (params.fp8_out) {
-    // Reduce amax over block
-    if (params.amax != nullptr) {
-      amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
-      if (threadIdx.x == 0) {
-        static_assert(std::is_same<compute_t, float>::value);
+  // Reduce amax over block
+  if (requires_amax) {
+    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
+    if (threadIdx.x == 0) {
+      static_assert(std::is_same<compute_t, float>::value);
...
@@ -142,6 +145,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
     }
   }
 
+  if (params.fp8_out) {
     // Update scale-inverse
     if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
       reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...
@@ -211,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
     scale = *reinterpret_cast<compute_t *>(params.scale);
   }
   compute_t amax = 0;
+  const bool requires_amax = params.amax != nullptr;
 
   for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
     const int row = cta_row + warp_m;
...
@@ -279,17 +284,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
     }
 
     // Apply fp8 factors
-    if (params.fp8_out) {
+    if (params.fp8_out || requires_amax) {
 #pragma unroll
       for (int jt = 0; jt < NUM_ELTS; jt++) {
         if (col + jt < params.cols) {
           compute_t z_ij = z.data.elt[jt];
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(z_ij));
-          z.data.elt[jt] = z_ij * scale;
+          if (params.fp8_out) {
+            z.data.elt[jt] = z_ij * scale;
+          }
         }
       }
     }
 
     // Store output
     Ovec z_out;
...
@@ -298,10 +305,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
     }
   }
 
-  // Finalize fp8 factors
-  if (params.fp8_out) {
-    // Reduce amax over block
-    if (params.amax != nullptr) {
-      amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
-      if (threadIdx.x == 0) {
-        static_assert(std::is_same<compute_t, float>::value);
+  // Reduce amax over block
+  if (requires_amax) {
+    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
+    if (threadIdx.x == 0) {
+      static_assert(std::is_same<compute_t, float>::value);
...
@@ -309,6 +314,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
     }
   }
 
+  if (params.fp8_out) {
     // Update scale-inverse
     if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
       reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...
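The recurring change in these kernels is to track amax whenever params.amax is set, independently of whether the output is actually quantized; the scale is applied only on the FP8 path. A stripped-down sketch of that per-element control flow, for illustration only (this is not the Ktraits-based kernels above):

    #include <math.h>

    // Per-element epilogue after normalization: amax tracking is gated on whether
    // an amax buffer was requested, not on fp8_out.
    __device__ inline float norm_epilogue(float value, float scale, bool fp8_out,
                                          bool requires_amax, float* thread_amax) {
      if (requires_amax) {
        *thread_amax = fmaxf(*thread_amax, fabsf(value));  // amax of the unscaled output
      }
      if (fp8_out) {
        value *= scale;  // quantization scale applied only for FP8 outputs
      }
      return value;
    }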
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
View file @ 87e3e56e
...
@@ -58,6 +58,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
   bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
 #endif
 
+  if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
+    cudnn_backend = false;
+    // cuDNN does not currently support amax output for non quantized output
+  }
+
   bool training =
       is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr;
...
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh
View file @ 87e3e56e
...
@@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
     scale = *reinterpret_cast<compute_t *>(params.scale);
   }
   compute_t amax = 0;
+  const bool requires_amax = params.amax != nullptr;
 
   for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
     Ivec x[LDGS];
...
@@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
         }
         compute_t temp_output = g_ij * y_ij;
 
-        if (params.fp8_out) {
+        if (requires_amax) {
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(temp_output));
+        }
+        if (params.fp8_out) {
           temp_output = temp_output * scale;
         }
...
@@ -124,9 +127,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
       idx += VEC_COLS_PER_LDG;
     }
   }
 
-  if (params.fp8_out) {
-    // Reduce amax over block
-    if (params.amax != nullptr) {
-      amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
-      if (threadIdx.x == 0) {
-        static_assert(std::is_same<compute_t, float>::value);
+  // Reduce amax over block
+  if (requires_amax) {
+    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
+    if (threadIdx.x == 0) {
+      static_assert(std::is_same<compute_t, float>::value);
...
@@ -134,6 +137,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
     }
   }
 
+  if (params.fp8_out) {
     // Update scale-inverse
     if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
       reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...
@@ -201,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
     scale = *reinterpret_cast<compute_t *>(params.scale);
   }
   compute_t amax = 0;
+  const bool requires_amax = params.amax != nullptr;
 
   for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
     const int row = cta_row + warp_m;
...
@@ -254,17 +259,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
     }
 
     // Apply fp8 factors
-    if (params.fp8_out) {
+    if (params.fp8_out || requires_amax) {
 #pragma unroll
       for (int jt = 0; jt < NUM_ELTS; jt++) {
         if (col + jt < params.cols) {
           compute_t z_ij = z.data.elt[jt];
           __builtin_assume(amax >= 0);
           amax = fmaxf(amax, fabsf(z_ij));
-          z.data.elt[jt] = z_ij * scale;
+          if (params.fp8_out) {
+            z.data.elt[jt] = z_ij * scale;
+          }
         }
       }
     }
 
     // Store output
     Ovec z_out;
...
@@ -273,10 +280,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
     }
   }
 
-  // Finalize fp8 factors
-  if (params.fp8_out) {
-    // Reduce amax over block
-    if (params.amax != nullptr) {
-      amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
-      if (threadIdx.x == 0) {
-        static_assert(std::is_same<compute_t, float>::value);
+  // Reduce amax over block
+  if (requires_amax) {
+    amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
+    if (threadIdx.x == 0) {
+      static_assert(std::is_same<compute_t, float>::value);
...
@@ -284,6 +289,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
     }
   }
 
+  if (params.fp8_out) {
     // Update scale-inverse
     if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
       reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
...
transformer_engine/common/swizzle/swizzle.cu
View file @ 87e3e56e
...
@@ -15,15 +15,17 @@
 #include "../util/logging.h"
 #include "transformer_engine/transformer_engine.h"
 
-namespace transformer_engine {
 namespace {
 
-constexpr int TB_DIM = 32;
-constexpr int NEW_SF_TILE_DIM_K = 16;
-constexpr int N_SF_PER_TD_PER_TILE = 4;
+constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
+constexpr __device__ __host__ int TB_DIM = 32;
+constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
+constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
 
 // output is in ~K-major interleaved blocks
-constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
-constexpr int NEW_SF_TILE_DIM_M_I32 = 32;
+constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4;
+constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32;
 
 template <typename LType>
 __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
...
@@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) {
 }
 
 template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
-__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M,
-                                           const int K) {
+__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M,
+                                                const int K, const int original_M,
+                                                const int original_K, const int bid_x,
+                                                const int bid_y, const int grid_dim_x,
+                                                const int grid_dim_y) {
   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
   constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
   constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
...
@@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
   int m_tiles_in_tb = N_TILE_PER_TD;
   int k_tiles_in_tb = TB_DIM;
-  if (blockIdx.x == gridDim.x - 1) {
+  if (bid_x == grid_dim_x - 1) {
     k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
   }
-  if (blockIdx.y == gridDim.y - 1) {
+  if (bid_y == grid_dim_y - 1) {
     m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
   }
 
-  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
-                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
-                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
+  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
+  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
+
+  const int input_offset = bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
+                           bid_y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
+  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) + input_offset;
   int32_t* output_i32[N_TILE_PER_TD];
 #pragma unroll
   for (int i = 0; i < m_tiles_in_tb; i++) {
-    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
-                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
+    output_i32[i] = reinterpret_cast<int32_t*>(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 +
+                    (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
   }
 
   extern __shared__ int slm[];
...
@@ -90,85 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons
       threadIdx.y < k_tiles_in_tb) {
 #pragma unroll
     for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
-      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
-          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD));
-    }
-
-    // local shuffle
-    regs_shuffle_with_bit_shifts(regs_vec);
-
-    // store, regs -> shared
-    int tM = threadIdx.x * N_SF_PER_TD;
-    int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 +
-                           tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32);
-#pragma unroll
-    for (int i = 0; i < N_SF_PER_TD; i++) {
-      /* TODO rotate_i */
-      slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 +
-               ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] =
-          reinterpret_cast<int*>(regs_vec)[i];
-    }
-  }
-  __syncthreads();
-
-  // store, shared -> global
-  int linear_id = threadIdx.y * blockDim.x + threadIdx.x;
-#pragma unroll
-  for (int i = 0; i < m_tiles_in_tb; i++) {
-    __align__(16) int4* output_v4i = reinterpret_cast<int4*>(output_i32[i]);
-    __align__(16) int4* slm_v4i = reinterpret_cast<int4*>(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32);
-#pragma unroll
-    for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; j += blockDim.x * blockDim.y) {
-      output_v4i[j] = slm_v4i[j];
-    }
-  }
-}
-
-#ifdef __HIP_PLATFORM_AMD__
-template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
-__global__ void swizzle_col_scaling_kernel_int(const void* input, void* output, const int M,
-                                               const int K) {
-  constexpr int N_TILE_PER_TD = sizeof(int) / sizeof(int);
-  constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
-  constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
-
-  // input is in M-major
-  constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4;
-  constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K;
-  const int M_i32 = M / 4;
-  const int K_i32 = K;
-
-  int m_tiles_in_tb = N_TILE_PER_TD;
-  int k_tiles_in_tb = TB_DIM;
-  if (blockIdx.x == gridDim.x - 1) {
-    k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1;
-  }
-  if (blockIdx.y == gridDim.y - 1) {
-    m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1;
-  }
-
-  const int32_t* input_i32 = reinterpret_cast<const int32_t*>(input) +
-                             blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 +
-                             blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32;
-  int32_t* output_i32[N_TILE_PER_TD];
-#pragma unroll
-  for (int i = 0; i < m_tiles_in_tb; i++) {
-    output_i32[i] = reinterpret_cast<int32_t*>(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 +
-                    (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32;
-  }
-
-  extern __shared__ int slm[];
-
-  // load, global -> regs
-  int regs_vec[N_SF_PER_TD_PER_TILE];
-  if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && threadIdx.y < k_tiles_in_tb) {
-#pragma unroll
-    for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
-      regs_vec[i] = *reinterpret_cast<const int*>(
-          input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD);
-    }
+      const int thread_offset =
+          (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD;
+      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
+      // Pad zeros
+      if (padding_m || padding_k) {
+        for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
+          const int index = (input_offset + thread_offset) * sizeof(int) + j;
+          if (index / M >= original_K || index % M >= original_M) {
+            reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
+          }
+        }
+      }
+    }
 
     // local shuffle
...
@@ -202,7 +143,14 @@ __global__ void swizzle_col_scaling_kernel_int(const void* input, void* output,
     }
   }
 }
-#endif
+
+template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
+__global__ void __launch_bounds__(TB_DIM* TB_DIM)
+    swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K,
+                               const int original_M, const int original_K) {
+  swizzle_col_scaling_kernel_impl<LType, SF_TILE_DIM_M, SF_TILE_DIM_K>(
+      input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y);
+}
 
 template <typename LType>
 __device__ inline void regs_shuffle(LType* regs_vec) {
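The padding branch added above zeroes every scaling-factor byte whose flattened index falls outside the original (unpadded) extents; whether the index is split by M or by K depends on whether the factors are laid out M-major (column-scaling path) or K-major (row-scaling path). A tiny host-side illustration of the same index test, using made-up dimensions rather than anything from this file:

    #include <cstdio>

    int main() {
      // Assumed example: padded extents M = 8, K = 4 for M-major factors,
      // original (unpadded) extents original_M = 6, original_K = 3.
      const int M = 8, K = 4, original_M = 6, original_K = 3;
      for (int index = 0; index < M * K; index++) {
        // Same predicate as the kernel's padding check for M-major data: the row
        // (index / M) must lie inside original_K and the position within the row
        // (index % M) inside original_M, otherwise the byte is zero-padded.
        const bool padded = (index / M >= original_K) || (index % M >= original_M);
        if (padded) printf("byte %d would be zero-padded\n", index);
      }
      return 0;
    }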
...
@@ -221,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) {
 }
 
 template <typename LType, int SF_TILE_DIM_M, int SF_TILE_DIM_K>
-__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M,
-                                           const int K) {
+__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M,
+                                                const int K, const int original_M,
+                                                const int original_K, const int bid_x,
+                                                const int bid_y, const int grid_dim_x,
+                                                const int grid_dim_y) {
   constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
   constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;
...
@@ -232,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
   int n_tiles_in_tb = N_TILES_IN_TB;
   const int K_i32 = K / 4;
-  if (blockIdx.x == gridDim.x - 1) {
+  if (bid_x == grid_dim_x - 1) {
     n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1;
   }
 
-  const int* input_i32 = reinterpret_cast<const int*>(input) +
-                         blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB;
-  int* output_i32 = reinterpret_cast<int*>(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 +
-                    blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
+  bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M);
+  bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K);
+
+  const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB;
+  const int* input_i32 = reinterpret_cast<const int*>(input) + input_offset;
+  int* output_i32 = reinterpret_cast<int*>(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 +
+                    bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32;
 
   extern __shared__ int4 slm_v4i[];
...
@@ -248,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
   if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) {
 #pragma unroll
     for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) {
-      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(
-          input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD));
+      const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD;
+      regs_vec[i] = __ldg(reinterpret_cast<const LType*>(input_i32 + thread_offset));
+      if (padding_m || padding_k) {
+        // Pad zeros
+        for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
+          const int index = (input_offset + thread_offset) * sizeof(int) + j;
+          if (index / K >= original_M || index % K >= original_K) {
+            reinterpret_cast<uint8_t*>(regs_vec + i)[j] = 0;
+          }
+        }
+      }
     }
 
     // shuffle regs
@@ -274,64 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
...
@@ -274,64 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons
}
}
}
}
#ifdef __HIP_PLATFORM_AMD__
template
<
typename
LType
,
int
SF_TILE_DIM_M
,
int
SF_TILE_DIM_K
>
template
<
int
SF_TILE_DIM_M
,
int
SF_TILE_DIM_K
>
__global__
void
__launch_bounds__
(
TB_DIM
*
TB_DIM
)
__global__
void
swizzle_row_scaling_kernel_int
(
const
void
*
input
,
void
*
output
,
const
int
M
,
swizzle_row_scaling_kernel
(
const
void
*
input
,
void
*
output
,
const
int
M
,
const
int
K
,
const
int
K
)
{
const
int
original_M
,
const
int
original_K
)
{
constexpr
int
N_TILE_PER_TD
=
sizeof
(
int
)
/
sizeof
(
int
);
swizzle_row_scaling_kernel_impl
<
LType
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
(
constexpr
int
N_TILES_IN_TB
=
TB_DIM
*
N_TILE_PER_TD
;
input
,
output
,
M
,
K
,
original_M
,
original_K
,
blockIdx
.
x
,
blockIdx
.
y
,
gridDim
.
x
,
gridDim
.
y
);
}
// input is in K-major
constexpr
int
kMaxTensorsPerKernel
=
64
;
// Args must be <4 KB
constexpr
int
SF_TILE_SIZE_I32
=
SF_TILE_DIM_M
*
SF_TILE_DIM_K
/
4
;
struct
MultiSwizzleArgs
{
constexpr
int
SF_TILE_DIM_M_I32
=
SF_TILE_DIM_M
;
// (input) Data buffers for input scaling factors
void
*
input_list
[
kMaxTensorsPerKernel
];
// (output) Data buffers for swizzled scaling factors
void
*
output_list
[
kMaxTensorsPerKernel
];
// Input scaling factor m
int
m_list
[
kMaxTensorsPerKernel
];
// Input scaling factor k
int
k_list
[
kMaxTensorsPerKernel
];
// Input scaling factor m before padding
int
original_m_list
[
kMaxTensorsPerKernel
];
// Input scaling factor k before padding
int
original_k_list
[
kMaxTensorsPerKernel
];
// Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor
int
block_range
[
kMaxTensorsPerKernel
+
1
];
// Number of tensors being processed by kernel
int
num_tensors
;
};
int
n_tiles_in_tb
=
N_TILES_IN_TB
;
template
<
typename
LType
,
int
SF_TILE_DIM_M
,
int
SF_TILE_DIM_K
>
const
int
K_i32
=
K
/
4
;
__global__
void
multi_tensor_swizzle_row_scaling_kernel
(
MultiSwizzleArgs
kernel_args
)
{
if
(
blockIdx
.
x
==
gridDim
.
x
-
1
)
{
// Find tensor corresponding to block
n_tiles_in_tb
=
(
K_i32
-
1
)
%
N_TILES_IN_TB
+
1
;
const
int
bid
=
blockIdx
.
x
;
}
int
tensor_id
=
0
;
while
(
kernel_args
.
block_range
[
tensor_id
+
1
]
<=
bid
)
{
++
tensor_id
;
}
// Get args corresponding to block
const
void
*
input
=
kernel_args
.
input_list
[
tensor_id
];
void
*
output
=
kernel_args
.
output_list
[
tensor_id
];
const
int
M
=
kernel_args
.
m_list
[
tensor_id
];
const
int
K
=
kernel_args
.
k_list
[
tensor_id
];
const
int
original_M
=
kernel_args
.
original_m_list
[
tensor_id
];
const
int
original_K
=
kernel_args
.
original_k_list
[
tensor_id
];
const
int
*
input_i32
=
reinterpret_cast
<
const
int
*>
(
input
)
+
constexpr
int
N_TILE_PER_TD
=
sizeof
(
LType
)
/
sizeof
(
int
);
blockIdx
.
y
*
SF_TILE_DIM_M_I32
*
K_i32
+
blockIdx
.
x
*
N_TILES_IN_TB
;
constexpr
int
N_TILES_IN_TB
=
TB_DIM
*
N_TILE_PER_TD
;
int
*
output_i32
=
reinterpret_cast
<
int
*>
(
output
)
+
blockIdx
.
y
*
SF_TILE_DIM_M_I32
*
K_i32
+
blockIdx
.
x
*
N_TILES_IN_TB
*
SF_TILE_SIZE_I32
;
extern
__shared__
int4
slm_v4i
[];
// Get block index in grid. Emulate 2D grid.
const
int
num_tiles_k
=
K
/
SF_TILE_DIM_K
;
const
int
num_tiles_m
=
M
/
SF_TILE_DIM_M
;
const
int
grid_dim_x
=
DIVUP
(
num_tiles_k
,
N_TILES_IN_TB
);
const
int
grid_dim_y
=
num_tiles_m
;
const
int
bid_x
=
(
bid
-
kernel_args
.
block_range
[
tensor_id
])
/
grid_dim_y
;
const
int
bid_y
=
(
bid
-
kernel_args
.
block_range
[
tensor_id
])
%
grid_dim_y
;
// load, global -> regs
swizzle_row_scaling_kernel_impl
<
LType
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
(
int
regs_vec
[
N_SF_PER_TD_PER_TILE
];
input
,
output
,
M
,
K
,
original_M
,
original_K
,
bid_x
,
bid_y
,
grid_dim_x
,
grid_dim_y
);
if
(
threadIdx
.
x
*
N_TILE_PER_TD
<
n_tiles_in_tb
)
{
}
#pragma unroll
for
(
int
i
=
0
;
i
<
N_SF_PER_TD_PER_TILE
;
i
++
)
{
regs_vec
[
i
]
=
*
reinterpret_cast
<
const
int
*>
(
input_i32
+
(
i
*
TB_DIM
+
threadIdx
.
y
)
*
K_i32
+
threadIdx
.
x
*
N_TILE_PER_TD
);
}
// shuffle regs
template
<
typename
LType
,
int
SF_TILE_DIM_M
,
int
SF_TILE_DIM_K
>
regs_shuffle
<
int
>
(
regs_vec
);
__global__
void
multi_tensor_swizzle_col_scaling_kernel
(
MultiSwizzleArgs
kernel_args
)
{
// Find tensor corresponding to block
const
int
bid
=
blockIdx
.
x
;
int
tensor_id
=
0
;
while
(
kernel_args
.
block_range
[
tensor_id
+
1
]
<=
bid
)
{
++
tensor_id
;
}
// Get args corresponding to block
const
void
*
input
=
kernel_args
.
input_list
[
tensor_id
];
void
*
output
=
kernel_args
.
output_list
[
tensor_id
];
const
int
M
=
kernel_args
.
m_list
[
tensor_id
];
const
int
K
=
kernel_args
.
k_list
[
tensor_id
];
const
int
original_M
=
kernel_args
.
original_m_list
[
tensor_id
];
const
int
original_K
=
kernel_args
.
original_k_list
[
tensor_id
];
// store, regs -> shared
constexpr
int
N_TILE_PER_TD
=
sizeof
(
LType
)
/
sizeof
(
int
);
#pragma unroll
constexpr
int
N_SF_PER_TD
=
N_TILE_PER_TD
*
N_SF_PER_TD_PER_TILE
;
for
(
int
i
=
0
;
i
<
N_TILE_PER_TD
;
i
++
)
{
constexpr
int
SF_TILE_SIZE_I32
=
SF_TILE_DIM_M
*
SF_TILE_DIM_K
/
4
;
/* TODO rotate i */
slm_v4i
[(
threadIdx
.
x
*
N_TILE_PER_TD
+
i
)
*
SF_TILE_SIZE_I32
/
4
+
threadIdx
.
y
]
=
reinterpret_cast
<
int4
*>
(
regs_vec
)[
i
];
}
}
__syncthreads
();
// store, shared -> global
// Get block index in grid. Emulate 2D grid.
int
linear_id
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
num_tiles_k
=
K
/
SF_TILE_DIM_K
;
__align__
(
16
)
int4
*
output_v4i
=
reinterpret_cast
<
int4
*>
(
output_i32
);
const
int
num_tiles_m
=
M
/
SF_TILE_DIM_M
;
#pragma unroll
const
int
grid_dim_x
=
DIVUP
(
num_tiles_k
,
TB_DIM
);
for
(
int
i
=
linear_id
;
i
<
SF_TILE_SIZE_I32
*
n_tiles_in_tb
/
4
;
i
+=
blockDim
.
x
*
blockDim
.
y
)
{
const
int
grid_dim_y
=
DIVUP
(
num_tiles_m
,
N_TILE_PER_TD
);
output_v4i
[
i
]
=
slm_v4i
[
i
];
const
int
bid_x
=
(
bid
-
kernel_args
.
block_range
[
tensor_id
])
/
grid_dim_y
;
}
const
int
bid_y
=
(
bid
-
kernel_args
.
block_range
[
tensor_id
])
%
grid_dim_y
;
swizzle_col_scaling_kernel_impl
<
LType
,
SF_TILE_DIM_M
,
SF_TILE_DIM_K
>
(
input
,
output
,
M
,
K
,
original_M
,
original_K
,
bid_x
,
bid_y
,
grid_dim_x
,
grid_dim_y
);
}
}
#endif
}
// namespace
namespace
transformer_engine
{
}
//
namespace
void
swizzle_scaling_factors
(
const
Tensor
*
input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
void
swizzle_scaling_factors
(
const
Tensor
*
input
,
Tensor
*
output
,
cudaStream_t
stream
)
{
if
(
!
is_fp8_dtype
(
input
->
dtype
())
||
is_delayed_tensor_scaling
(
input
->
scaling_mode
))
{
if
(
!
is_fp8_dtype
(
input
->
dtype
())
||
is_delayed_tensor_scaling
(
input
->
scaling_mode
))
{
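The device side of the multi-tensor kernels walks block_range (a prefix sum with a leading zero) to map a flat blockIdx.x onto a tensor and then onto an emulated 2-D grid. The host-side packing of MultiSwizzleArgs is not part of this excerpt; the following is a sketch under the assumption that the launcher fills it roughly like this, with an illustrative TensorDesc and a stand-in num_blocks_for() callback:

    #include <vector>

    struct TensorDesc {  // illustrative only
      void* scale_inv;
      void* swizzled_scale_inv;
      int m, k, original_m, original_k;
    };

    // Returns the total number of CUDA blocks the fused launch needs.
    template <typename Args>
    int fill_multi_swizzle_args(const std::vector<TensorDesc>& tensors, Args* args,
                                int (*num_blocks_for)(const TensorDesc&)) {
      args->num_tensors = static_cast<int>(tensors.size());
      args->block_range[0] = 0;  // leading zero of the prefix sum
      for (int i = 0; i < args->num_tensors; ++i) {
        args->input_list[i] = tensors[i].scale_inv;
        args->output_list[i] = tensors[i].swizzled_scale_inv;
        args->m_list[i] = tensors[i].m;
        args->k_list[i] = tensors[i].k;
        args->original_m_list[i] = tensors[i].original_m;
        args->original_k_list[i] = tensors[i].original_k;
        // block_range[i + 1] - block_range[i] == blocks needed by tensor i,
        // which is exactly the interval the device-side while loop searches.
        args->block_range[i + 1] = args->block_range[i] + num_blocks_for(tensors[i]);
      }
      return args->block_range[args->num_tensors];
    }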
...
@@ -385,50 +383,52 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int
n_tiles_in_tb
=
TB_DIM
*
vec_load_size
;
      int n_tiles_in_tb = TB_DIM * vec_load_size;
      dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
      int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);

      const int original_M = input->flat_first_dim();
      const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;

      switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
        case 4:
          cudaFuncSetAttribute(
              (const void*)swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                             output->scale_inv.dptr, m, k,
                                                             original_M, original_K);
          break;
        case 2:
          cudaFuncSetAttribute(
              (const void*)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                             output->scale_inv.dptr, m, k,
                                                             original_M, original_K);
          break;
        case 1:
          cudaFuncSetAttribute(
              (const void*)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                             output->scale_inv.dptr, m, k,
                                                             original_M, original_K);
          break;
#else
        case 4:
          cudaFuncSetAttribute(swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                             output->scale_inv.dptr, m, k,
                                                             original_M, original_K);
          break;
        case 2:
          cudaFuncSetAttribute(swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                             output->scale_inv.dptr, m, k,
                                                             original_M, original_K);
          break;
        case 1:
          cudaFuncSetAttribute(swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->scale_inv.dptr,
                                                             output->scale_inv.dptr, m, k,
                                                             original_M, original_K);
          break;
#endif
        default:
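A quick sanity check on the grid math above: the rowwise path tiles the padded scale-factor matrix into 128x4 tiles and hands each thread block TB_DIM * vec_load_size tiles along K. The sketch below is illustrative only; the value of TB_DIM and the example shapes are assumptions, not taken from this diff.

// Minimal host-side sketch (not TE code) of the rowwise grid sizing shown above.
// Assumptions: TB_DIM = 32, SF_TILE_DIM_M = 128, SF_TILE_DIM_K = 4,
// and a padded scale-factor tensor of 256 x 64 entries.
#include <cstdio>

constexpr int div_up(int a, int b) { return (a + b - 1) / b; }  // stand-in for DIVUP

int main() {
  const int TB_DIM = 32;        // assumed thread-block edge
  const int m = 256, k = 64;    // padded scale-factor shape
  const int num_tiles_m = m / 128;
  const int num_tiles_k = k / 4;
  const int vec_load_size = 4;  // int4 path
  const int n_tiles_in_tb = TB_DIM * vec_load_size;
  std::printf("rowwise grid = (%d, %d)\n", div_up(num_tiles_k, n_tiles_in_tb), num_tiles_m);
  return 0;
}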
...
@@ -442,50 +442,58 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
      int n_tiles_in_tb = TB_DIM * vec_load_size;
      dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size));
      int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);

      const int original_M = input->flat_last_dim();
      const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;

      switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
        case 4:
          cudaFuncSetAttribute(
              (const void*)swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                             output->columnwise_scale_inv.dptr, m,
                                                             k, original_M, original_K);
          break;
        case 2:
          cudaFuncSetAttribute(
              (const void*)swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                             output->columnwise_scale_inv.dptr, m,
                                                             k, original_M, original_K);
          break;
        case 1:
          cudaFuncSetAttribute(
              (const void*)swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
              cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                             output->columnwise_scale_inv.dptr, m,
                                                             k, original_M, original_K);
          break;
#else
        case 4:
          cudaFuncSetAttribute(swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                             output->columnwise_scale_inv.dptr, m,
                                                             k, original_M, original_K);
          break;
        case 2:
          cudaFuncSetAttribute(swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                             output->columnwise_scale_inv.dptr, m,
                                                             k, original_M, original_K);
          break;
        case 1:
          cudaFuncSetAttribute(swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
                               cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
          swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
              <<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
                                                             output->columnwise_scale_inv.dptr, m,
                                                             k, original_M, original_K);
          break;
#endif
        default:
...
@@ -498,10 +506,260 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
  } else {
    NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
  }

  NVTE_CHECK_CUDA(cudaGetLastError());
}
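The manual printf/exit error handling is replaced by NVTE_CHECK_CUDA. Its definition is not part of this diff; a macro in that spirit usually looks roughly like the sketch below (CHECK_CUDA_SKETCH is a made-up name, not the TE macro).

// Hedged sketch of a cudaError_t-checking macro; the real NVTE_CHECK_CUDA lives
// elsewhere in the codebase and may report errors differently.
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

#define CHECK_CUDA_SKETCH(expr)                                                \
  do {                                                                         \
    cudaError_t status_ = (expr);                                              \
    if (status_ != cudaSuccess) {                                              \
      throw std::runtime_error(std::string("CUDA error: ") +                   \
                               cudaGetErrorString(status_));                   \
    }                                                                          \
  } while (0)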
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
                                                 const int vec_load_size, const bool is_rowwise,
                                                 cudaStream_t stream) {
  int n_tiles_in_tb = TB_DIM * vec_load_size;
  int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);

  /* Calculate number of CUDA blocks needed for each tensor.
   * We have to do it here because we have to iterate over all tensors in this batch to
   * get the minimum vec_load_size.
   */
  for (size_t j = 0; j < kernel_args.num_tensors; j++) {
    const int m = kernel_args.m_list[j];
    const int k = kernel_args.k_list[j];
    int num_tiles_m = m / SF_TILE_DIM_M;
    int num_tiles_k = k / SF_TILE_DIM_K;
    if (is_rowwise) {
      kernel_args.block_range[j + 1] =
          kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m;
    } else {
      kernel_args.block_range[j + 1] =
          kernel_args.block_range[j] +
          DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size);
    }
  }

  // Launch kernel
  const int num_blocks = kernel_args.block_range[kernel_args.num_tensors];
  dim3 block_size(TB_DIM, TB_DIM);
  if (is_rowwise) {
    switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
      case 4:
        cudaFuncSetAttribute(
            (const void*)multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        cudaFuncSetAttribute(
            (const void*)multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        cudaFuncSetAttribute(
            (const void*)multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#else
      case 4:
        cudaFuncSetAttribute(
            multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        cudaFuncSetAttribute(
            multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        cudaFuncSetAttribute(
            multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#endif
      default:
        NVTE_ERROR("Not valid vec_load_size.");
        break;
    }
  } else {
    switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
      case 4:
        cudaFuncSetAttribute(
            (const void*)multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        cudaFuncSetAttribute(
            (const void*)multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        cudaFuncSetAttribute(
            (const void*)multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#else
      case 4:
        cudaFuncSetAttribute(
            multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 2:
        cudaFuncSetAttribute(
            multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
      case 1:
        cudaFuncSetAttribute(
            multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
            cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
        multi_tensor_swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
            <<<num_blocks, block_size, slm_size, stream>>>(kernel_args);
        break;
#endif
      default:
        NVTE_ERROR("Not valid vec_load_size.");
        break;
    }
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}
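block_range ends up holding an exclusive prefix sum of the per-tensor block counts, which is how a single grid can cover many tensors at once. The multi-tensor kernels themselves are not shown in this hunk; the device-side lookup below is only a plausible sketch of how a block might recover its tensor index, not the actual kernel code.

// Hedged CUDA sketch (illustrative, belongs in a .cu file): map a global blockIdx.x
// onto (tensor_id, local_block) using the prefix sums stored in block_range[].
__device__ void locate_tensor(const int* block_range, int num_tensors,
                              int* tensor_id, int* local_block) {
  int t = 0;
  // A linear scan is fine here: num_tensors is bounded by kMaxTensorsPerKernel.
  while (t + 1 < num_tensors && static_cast<int>(blockIdx.x) >= block_range[t + 1]) {
    ++t;
  }
  *tensor_id = t;
  *local_block = static_cast<int>(blockIdx.x) - block_range[t];
}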
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
                                          std::vector<Tensor*>& output, cudaStream_t stream) {
  auto num_tensors = input.size();
  bool all_has_data = true;
  bool all_has_columnwise_data = true;
  for (size_t i = 0; i < num_tensors; i++) {
    if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) {
      NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + ".");
    }
    // We don't allow empty tensors. They should be filtered out before calling this function.
    if (input[i]->data.numel() == 0) {
      NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
    }
    CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
    CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
    all_has_data &= input[i]->has_data();
    all_has_columnwise_data &= input[i]->has_columnwise_data();
  }
  NVTE_CHECK(all_has_data || all_has_columnwise_data,
             "All tensors should have data or columnwise data.");

  constexpr int SF_TILE_DIM_M = 128;
  constexpr int SF_TILE_DIM_K = 4;

  if (all_has_data) {
    MultiSwizzleArgs kernel_args;
    kernel_args.num_tensors = 0;
    kernel_args.block_range[0] = 0;
    int vec_load_size = 4;
    for (size_t i = 0; i < num_tensors; i++) {
      // Launch kernel if argument struct is full
      if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
        // There is no int3 and misaligned if using int4/int2.
        if (vec_load_size == 3) vec_load_size = 1;
        launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
            kernel_args, vec_load_size, true, stream);
        // Reset the argument struct and vec_load_size
        kernel_args.num_tensors = 0;
        vec_load_size = 4;
      }
      const int m = input[i]->scale_inv.shape[0];
      const int k = input[i]->scale_inv.shape[1];
      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
      NVTE_CHECK(m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
                                          output[i]->scale_inv.shape.end(), 1,
                                          std::multiplies<int>()),
                 "Input.scale_inv size is not equal to Output.scale_inv size!");

      int num_tiles_k = k / SF_TILE_DIM_K;
      int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
      // We use the minimum vec_load_size across all tensors.
      vec_load_size = std::min(vec_load_size, vec_load_size_i);

      const int pos = kernel_args.num_tensors;
      kernel_args.input_list[pos] = const_cast<void*>(input[i]->scale_inv.dptr);
      kernel_args.output_list[pos] = output[i]->scale_inv.dptr;
      kernel_args.m_list[pos] = m;
      kernel_args.k_list[pos] = k;
      kernel_args.original_m_list[pos] = input[i]->flat_first_dim();
      kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE;
      kernel_args.num_tensors++;
    }
    // Launch the remaining tensors
    // There is no int3 and misaligned if using int4/int2.
    if (vec_load_size == 3) vec_load_size = 1;
    launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
        kernel_args, vec_load_size, true, stream);
  }

  if (all_has_columnwise_data) {
    MultiSwizzleArgs kernel_args;
    kernel_args.num_tensors = 0;
    kernel_args.block_range[0] = 0;
    int vec_load_size = 4;
    for (size_t i = 0; i < num_tensors; i++) {
      // Launch kernel if argument struct is full
      if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
        // There is no int3 and misaligned if using int4/int2.
        if (vec_load_size == 3) vec_load_size = 1;
        launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
            kernel_args, vec_load_size, false, stream);
        // Reset the argument struct and vec_load_size
        kernel_args.num_tensors = 0;
        vec_load_size = 4;
      }
      const int m = input[i]->columnwise_scale_inv.shape[1];
      const int k = input[i]->columnwise_scale_inv.shape[0];
      NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
      NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
      NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
      NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
                                          output[i]->columnwise_scale_inv.shape.end(), 1,
                                          std::multiplies<int>()),
                 "Input.columnwise_scale_inv size is not equal to "
                 "Output.columnwise_scale_inv size!");

      int num_tiles_k = k / SF_TILE_DIM_K;
      int vec_load_size_i = (num_tiles_k - 1) % 4 + 1;
      // We use the minimum vec_load_size across all tensors.
      vec_load_size = std::min(vec_load_size, vec_load_size_i);

      const int pos = kernel_args.num_tensors;
      kernel_args.input_list[pos] = const_cast<void*>(input[i]->columnwise_scale_inv.dptr);
      kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr;
      kernel_args.m_list[pos] = m;
      kernel_args.k_list[pos] = k;
      kernel_args.original_m_list[pos] = input[i]->flat_last_dim();
      kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE;
      kernel_args.num_tensors++;
    }
    // Launch the remaining tensors
    // There is no int3 and misaligned if using int4/int2.
    if (vec_load_size == 3) vec_load_size = 1;
    launch_multi_tensor_swizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
        kernel_args, vec_load_size, false, stream);
  }
}

}  // namespace transformer_engine
...
@@ -516,3 +774,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
  using namespace transformer_engine;
  swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
                                               const size_t num_tensors, cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors);
  using namespace transformer_engine;
  NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0.");
  std::vector<Tensor*> input_list, output_list;
  for (size_t i = 0; i < num_tensors; i++) {
    input_list.push_back(convertNVTETensorCheck(inputs[i]));
    output_list.push_back(convertNVTETensorCheck(outputs[i]));
  }
  multi_tensor_swizzle_scaling_factors(input_list, output_list, stream);
}
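The per-tensor vector width above comes from (num_tiles_k - 1) % 4 + 1, which yields 4, 2, 1 or 3 depending on how the tile count is aligned, with 3 folded back to 1 because there is no int3 load. A small standalone sketch of that batched selection, mirroring the code above:

// Illustrative host-side sketch (not TE code) of the minimum-across-tensors
// vec_load_size selection used by multi_tensor_swizzle_scaling_factors.
#include <algorithm>
#include <cassert>

int choose_vec_load_size(const int* num_tiles_k, int n) {
  int vec = 4;
  for (int i = 0; i < n; ++i) {
    int v = (num_tiles_k[i] - 1) % 4 + 1;  // 4, 1, 2 or 3 depending on alignment
    vec = std::min(vec, v);
  }
  if (vec == 3) vec = 1;  // there is no int3 vector load
  return vec;
}

int main() {
  int tiles[] = {16, 6, 8};  // e.g. K = 64, 24, 32 with SF_TILE_DIM_K = 4
  assert(choose_vec_load_size(tiles, 3) == 2);  // int2 is the widest common width
  return 0;
}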
transformer_engine/common/transformer_engine.cpp
...
@@ -197,7 +197,8 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
    }
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
    // Note: amax is supported for non-FP8 output as it can be fused into the computation
    // and later used for quantization with no need to compute it separately
    NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 input ", name);
...
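The relaxed check reflects that an amax buffer on a non-FP8 output can be filled as a by-product of the kernel that produces the tensor and then consumed when that tensor is later quantized to FP8. A simplified, sequential sketch of the pattern (plain C++, not TE code):

// Hedged sketch: a producer records amax while writing a high-precision output,
// so a later quantization step can derive its scale without an extra pass over the data.
#include <algorithm>
#include <cmath>

void producer_with_amax(const float* in, float* out, float* amax, int n) {
  float local_amax = 0.f;
  for (int i = 0; i < n; ++i) {
    out[i] = in[i];  // stand-in for the BF16/FP32 output written by the real kernel
    local_amax = std::max(local_amax, std::fabs(in[i]));
  }
  *amax = local_amax;  // consumed later to build the FP8 scale
}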
transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu
...
@@ -18,6 +18,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
...
@@ -184,6 +185,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
    }
  }

  // Trigger the next kernel here so that its load from global memory can overlap with this
  // kernel's store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  cudaTriggerProgrammaticLaunchCompletion();
#endif

  // Step 3: Store cast output, Step 4: do transpose within thread tile
  OVecCast tmp_output_c;
...
@@ -419,6 +426,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
    }
  }

  // Trigger the next kernel here so that its load from global memory can overlap with this
  // kernel's store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  cudaTriggerProgrammaticLaunchCompletion();
#endif

  // Step 3: Store cast output, Step 4: do transpose within thread tile
  // Edge case: in the non-full tile case, there are three subcases
  // for full thread tile, it's the same thing here
...
@@ -883,6 +896,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
                                         const bool return_transpose, const bool pow_2_scale,
                                         cudaStream_t stream) {
  NVTE_API_CALL(quantize_transpose_square_blockwise);
  checkCuDriverContext(stream);
  NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
  const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
  size_t num_rows = 1;
...
@@ -917,9 +932,23 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
    scale_t_stride_y = scale_inv_t.shape[1];
  }

#ifdef __HIP_PLATFORM_AMD__
  const size_t block_len = blockwise_fp8_block_len();
  const size_t num_blocks_x = DIVUP(row_length, block_len);
  const size_t num_blocks_y = DIVUP(num_rows, block_len);
#else
  const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
  const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
  dim3 grid(num_blocks_x, num_blocks_y, 1);

  cudaLaunchAttribute attribute[1];
  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attribute[0].val.programmaticStreamSerializationAllowed = 1;
  cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
  if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
    cfg.attrs = attribute;
    cfg.numAttrs = 1;
  }
#endif

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
...
@@ -929,10 +958,13 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              return_transpose, kReturnTranspose,
#ifdef __HIP_PLATFORM_AMD__
              dim3 grid(num_blocks_x, num_blocks_y, 1);
              const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
              const bool full_tile =
                  row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
#endif

              if (full_tile) {
#ifndef __HIP_PLATFORM_AMD__
                CUtensorMap tensor_map_output_trans;
...
@@ -940,15 +972,28 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
                  tensor_map_output_trans =
                      get_tensor_map<OutputType>(output_t, num_rows, row_length);
                }
                cudaLaunchKernelEx(&cfg,
                                   block_scaled_cast_transpose_kernel<kReturnTranspose, float,
                                                                      InputType, OutputType>,
                                   reinterpret_cast<const InputType*>(input.dptr),
                                   reinterpret_cast<OutputType*>(output.dptr),
                                   reinterpret_cast<OutputType*>(output_t.dptr),
                                   reinterpret_cast<float*>(scale_inv.dptr),
                                   reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
                                   scale_stride_x, scale_stride_y, scale_t_stride_x,
                                   scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
              } else {
                cudaLaunchKernelEx(&cfg,
                                   block_scaled_cast_transpose_kernel_notaligned<
                                       kReturnTranspose, float, InputType, OutputType>,
                                   reinterpret_cast<const InputType*>(input.dptr),
                                   reinterpret_cast<OutputType*>(output.dptr),
                                   reinterpret_cast<OutputType*>(output_t.dptr),
                                   reinterpret_cast<float*>(scale_inv.dptr),
                                   reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
                                   scale_stride_x, scale_stride_y, scale_t_stride_x,
                                   scale_t_stride_y, epsilon, pow_2_scale);
#else
              while (true) {
                if (128 == block_len) {
...
@@ -978,7 +1023,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
                    break;
                  }
                }
              } else {
                while (true) {
                  if (128 == block_len) {
...
@@ -1008,6 +1052,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
                    break;
                  }
                }
#endif
              }  // full-tile
          )  // return_transpose
      )  // OutputType
...
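The cudaTriggerProgrammaticLaunchCompletion() calls and the cudaLaunchAttributeProgrammaticStreamSerialization launch attribute together use programmatic dependent launch (PDL): the next kernel in the stream may begin its prologue before this one has fully drained. A minimal standalone sketch of the mechanism, with made-up producer/consumer kernels rather than the TE ones:

// Hedged sketch of PDL. The dependent (consumer) launch carries the attribute;
// the producer signals early completion, and the consumer waits for real data.
#include <cuda_runtime.h>

__global__ void producer(float* out) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Let the dependent kernel start its prologue while we finish storing.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
  out[threadIdx.x] = static_cast<float>(threadIdx.x);
}

__global__ void consumer(const float* in, float* out) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Wait until the producer's results are actually visible before reading them.
  cudaGridDependencySynchronize();
#endif
  out[threadIdx.x] = in[threadIdx.x] * 2.f;
}

int main() {
  float *a, *b;
  cudaMalloc(&a, 32 * sizeof(float));
  cudaMalloc(&b, 32 * sizeof(float));

  cudaLaunchAttribute attr[1];
  attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr[0].val.programmaticStreamSerializationAllowed = 1;

  int major = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);

  cudaLaunchConfig_t cfg = {dim3(1), dim3(32), 0, 0, nullptr, 0};
  cudaLaunchKernelEx(&cfg, producer, a);
  if (major >= 9) {  // only Hopper and newer honor the PDL attribute
    cfg.attrs = attr;
    cfg.numAttrs = 1;
  }
  cudaLaunchKernelEx(&cfg, consumer, static_cast<const float*>(a), b);
  cudaDeviceSynchronize();
  cudaFree(a);
  cudaFree(b);
  return 0;
}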
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu
...
@@ -24,6 +24,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh"

namespace transformer_engine {
...
@@ -74,7 +75,7 @@ Step 2: Cast and store to output_c
 * What each thread does in each loop:
 * 2 elements are read from the shared memory at a time, for a total of 8 times
 * Every 8 consecutive threads do reduction and calculate the amax of each row
 * 16 elements are quantized and write to output_c at a time
 +-------------------------------+-------------------------------+-------------------------------+-------------------------------+
 | T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
...
@@ -251,6 +252,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
  __syncthreads();

  // If not returning columnwise output, trigger the next kernel here so that its load from
  // global memory can overlap with this kernel's rowwise stores.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif

  // Step 2: Cast and store to output_c
  if (return_rowwise) {
    constexpr int r_stride =
...
@@ -356,6 +365,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
    }
  }

  // If returning columnwise output, trigger the next kernel here so that its load from
  // global memory can overlap with this kernel's columnwise stores.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  if (return_columnwise_gemm_ready || return_columnwise_compact) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif

  // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
  if (return_columnwise_gemm_ready) {
    constexpr int c_stride =
...
@@ -1424,20 +1441,30 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
    scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
  }

#ifdef __HIP_PLATFORM_AMD__
  const size_t block_len = blockwise_fp8_block_len();
  const size_t num_blocks_x = DIVUP(row_length, (size_t)block_len);
  const size_t num_blocks_y = DIVUP(num_rows, (size_t)block_len);
#else
  const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
  const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
  dim3 grid(num_blocks_x, num_blocks_y, 1);

  cudaLaunchAttribute attribute[1];
  attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attribute[0].val.programmaticStreamSerializationAllowed = 1;
#endif

  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(
          output.dtype, OutputType,
#ifdef __HIP_PLATFORM_AMD__
          dim3 grid(num_blocks_x, num_blocks_y, 1);
          const bool full_tile = row_length % block_len == 0 && num_rows % block_len == 0;
#else
          const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
#endif
          TRANSFORMER_ENGINE_SWITCH_CONDITION(
              full_tile, kAligned,
...
@@ -1505,21 +1532,30 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
              }
#else
              size_t smem_bytes = kSMemSize * sizeof(InputType);

              cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
              if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
                cfg.attrs = attribute;
                cfg.numAttrs = 1;
              }

              // shared memory must be requested up front
              if (smem_bytes >= 48 * 1024) {
                cudaError_t err = cudaFuncSetAttribute(
                    &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
                    cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
                NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
              }
              cudaLaunchKernelEx(&cfg,
                                 block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
                                                                       OutputType>,
                                 reinterpret_cast<const InputType*>(input.dptr),
                                 reinterpret_cast<OutputType*>(output.dptr),
                                 reinterpret_cast<OutputType*>(output_t.dptr),
                                 reinterpret_cast<float*>(scale_inv.dptr),
                                 reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
                                 scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y,
                                 epsilon, rowwise_option, columnwise_option, pow2_scale);
#endif
          )  // kAligned
      )  // OutputType
...
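As in the launcher above, any kernel that wants more than 48 KB of dynamic shared memory has to opt in with cudaFuncSetAttribute before it is launched. A small standalone sketch of that opt-in (kernel name and sizes are illustrative, not TE code):

// Hedged sketch: requesting an above-default dynamic shared memory allocation.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void smem_kernel(float* out) {
  extern __shared__ float buf[];
  buf[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
  const size_t smem_bytes = 96 * 1024;  // above the 48 KB default opt-in limit
  cudaError_t err = cudaFuncSetAttribute(&smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                         static_cast<int>(smem_bytes));
  if (err != cudaSuccess) {
    std::printf("shared memory opt-in failed: %s\n", cudaGetErrorString(err));
    return 1;
  }
  float* out;
  cudaMalloc(&out, 256 * sizeof(float));
  smem_kernel<<<1, 256, smem_bytes>>>(out);
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}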