Commit f8c2af4c authored by yuguo's avatar yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
@@ -562,5 +562,75 @@ size_t get_max_tokens(size_t num_tokens) {
  return max_t;
}
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid > 0) return;
rng_state_dst[0] = seed[0];
rng_state_dst[1] = offset;
}
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) return;
if (cu_seqlen[tid] > 0) {
    // atomicAdd only supports 32-bit dtypes
atomicAdd(out, 1);
}
}
void PopulateRngStateAsync(void *rng_state_dst, const void *seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream) {
size_t increment = 0;
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
increment = 16;
} else {
constexpr int threads_per_cta = 128;
increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
}
auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
reinterpret_cast<const int64_t *>(seed), offset);
NVTE_CHECK_CUDA(cudaGetLastError());
}
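For intuition, a worked example of the offset arithmetic above (illustrative numbers, not part of the patch):
// With the non-arbitrary-seqlen path and q_max_seqlen = kv_max_seqlen = 512,
// the RNG offset reserved per launch is ceil(512 * 512 / 128) = 2048 counter
// increments; the arbitrary-seqlen backend instead reserves a fixed 16.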
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
  // requires a 4-byte workspace
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
return hout;
}
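A minimal host-side usage sketch (buffer names are hypothetical):
// Count the non-empty segments described by `d_cu_seqlens` (device pointer).
// The workspace must provide at least 4 bytes of device memory.
uint32_t *d_ws = nullptr;
cudaMalloc(&d_ws, sizeof(uint32_t));
uint32_t num_segments = GetRuntimeNumSegments(d_cu_seqlens, d_ws, len, stream);
// Note: the call synchronizes `stream` before returning the host value.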
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
if (captured) {
rng_state_ptr[0] = *seed_ptr;
rng_state_ptr[1] = static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
} else {
rng_state_ptr[0] = static_cast<int64_t>(seed_val);
rng_state_ptr[1] = static_cast<int64_t>(offset_val);
}
}
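In rough terms, the two branches implement the following contract (an illustration, not code from the patch):
// captured == true  (CUDA graph capture/replay): read seed and offset through
//                    device pointers and add the intra-graph offset:
//                    rng_state = {*seed_ptr, *offset_ptr + offset_intragraph}
// captured == false: use the plain host-passed values:
//                    rng_state = {seed_val, offset_val}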
} // namespace fused_attn
} // namespace transformer_engine
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream) {
NVTE_API_CALL(nvte_extract_seed_and_offset);
using namespace transformer_engine;
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
}
@@ -150,6 +150,38 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
class FusedAttnOffsetManager {
public:
static FusedAttnOffsetManager &Instance() {
static thread_local FusedAttnOffsetManager instance;
return instance;
}
size_t GetAndUpdateOffset(size_t increment) {
size_t ret = offset_;
offset_ += increment;
return ret;
}
FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
void operator=(FusedAttnOffsetManager const &) = delete;
private:
FusedAttnOffsetManager() {}
size_t offset_ = 0;
};
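As a sketch of the singleton's contract (illustrative values only):
// auto &mgr = FusedAttnOffsetManager::Instance();
// size_t a = mgr.GetAndUpdateOffset(16);  // a == 0, running offset becomes 16
// size_t b = mgr.GetAndUpdateOffset(32);  // b == 16, running offset becomes 48
// The instance is thread_local, so each thread keeps its own running offset.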
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset);
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
} // namespace fused_attn
} // namespace transformer_engine
@@ -115,10 +115,10 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
                                          const float *freqs, const int *start_positions,
                                          scalar_t *dst, const bool interleaved, const int cp_size,
                                          const int cp_rank, const int s, const int h, const int d,
                                          const int d2, const int stride_s_or_t, const int stride_b,
                                          const int stride_h, const int stride_d,
                                          const int o_stride_s_or_t, const int o_stride_b,
                                          const int o_stride_h, const int o_stride_d) {
@@ -149,7 +149,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
    }
  } else {
    int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
    s_id_for_freqs = s_id + begin_offset;
  }
  fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
@@ -199,11 +200,12 @@ __global__ void fused_rope_backward_kernel(
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
                                 const int *start_positions, scalar_t *output,
                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
                                 const int cp_size, const int cp_rank, const int s, const int b,
                                 const int h, const int d, const int d2, const int stride_s_or_t,
                                 const int stride_b, const int stride_h, const int stride_d,
                                 cudaStream_t stream) {
  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(THREADS_PER_WARP, warps_per_block);
@@ -223,8 +225,9 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
  const int o_stride_d = 1;
  fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
      input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
      stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
      o_stride_d);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
@@ -262,15 +265,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
}
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
                        const Tensor &start_positions, Tensor *output,
                        const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
                        const int cp_rank, const int s, const int b, const int h, const int d,
                        const int d2, const int stride_s_or_t, const int stride_b,
                        const int stride_h, const int stride_d, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, scalar_t,
      fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
                                  reinterpret_cast<const int *>(cu_seqlens.data.dptr),
                                  reinterpret_cast<const float *>(freqs.data.dptr),
                                  reinterpret_cast<const int *>(start_positions.data.dptr),
                                  reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
                                  interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
                                  stride_b, stride_h, stride_d, stream););
@@ -295,19 +300,19 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
                             const NVTETensor freqs, const NVTETensor start_positions,
                             NVTETensor output, const NVTE_QKV_Format qkv_format,
                             const bool interleaved, const int cp_size, const int cp_rank,
                             const int s, const int b, const int h, const int d, const int d2,
                             const int stride_s_or_t, const int stride_b, const int stride_h,
                             const int stride_d, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_rope_forward);
  using namespace transformer_engine;
  fused_rope_forward(
      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
      *reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
      reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
      stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
@@ -11,8 +11,7 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include "stdint.h"
#include "transformer_engine.h"
#ifdef __cplusplus
@@ -245,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
                                   NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                                   const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
                                   const NVTETensor rng_state, size_t max_seqlen, bool is_training,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -301,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 */
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
                                   const NVTETensor S, NVTETensor dP,
                                   const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
                                   NVTETensor dBias, const NVTETensor cu_seqlens,
                                   const NVTETensor cu_seqlens_padded, size_t max_seqlen,
                                   float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
@@ -369,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
 */
void nvte_fused_attn_fwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
    NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
@@ -430,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
 */
void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
    const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
    NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
@@ -501,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
 */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor Bias, NVTETensor S, NVTETensor O,
                         NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
@@ -570,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
 */
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
                         const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
                         NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
                         const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
@@ -580,6 +579,76 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                         int64_t window_size_right, bool deterministic, NVTETensor workspace,
                         cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
* \param[in] rng_state_dst RNG state to store seed and offset.
* \param[in] seed Seed for RNG state.
 * \param[in] q_max_seqlen Max sequence length used in the computation for Q;
 *                         it may be >= max(seqlen_q_i) for i = 0, ..., batch_size - 1.
 * \param[in] kv_max_seqlen Max sequence length used in the computation for K and V;
 *                          it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size - 1.
* \param[in] backend Fused attention backend.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
/*! \brief Get the number of non-empty sequence segments at runtime.
 *
 * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
 * \param[in] workspace Workspace tensor (at least 4 bytes).
 * \param[in] len batch_size x sequence_length.
 * \param[in] stream CUDA stream used for this operation.
 */
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream);
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
@@ -20,6 +20,7 @@ extern "C" {
 * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
 * (Required for the thd format, empty tensor for other formats)
 * \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
 * \param[out] output Output tensor.
 * \param[in] qkv_format QKV format.
 * \param[in] interleaved Whether to use interleaved rotary position embedding.
@@ -37,12 +38,12 @@
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
                             const NVTETensor freqs, const NVTETensor start_positions,
                             NVTETensor output, const NVTE_QKV_Format qkv_format,
                             const bool interleaved, const int cp_size, const int cp_rank,
                             const int s, const int b, const int h, const int d, const int d2,
                             const int stride_s_or_t, const int stride_b, const int stride_h,
                             const int stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
 *
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file multi_tensor.h
* \brief Functions handling multi tensor kernels.
*/
#ifndef TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#define TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_MULTI_TENSOR_H_
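As a rough usage sketch of the new C API (handle names are hypothetical, and the Apex-style convention of one input list and one output list is an assumption here):
// Scale `n` tensors by 0.5 in fused chunks; `noop` is a single-int device
// tensor, 0 meaning "run the kernel".
NVTETensor *tensor_lists[2] = {src_tensors, dst_tensors};
nvte_multi_tensor_scale_cuda(/*chunk_size=*/65536, noop, tensor_lists,
                             /*num_tensor_lists=*/2, /*num_tensors_per_list=*/n,
                             /*scale=*/0.5f, /*device_id=*/0, stream);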
@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
                    const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
                    cudaStream_t stream = nullptr);
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out, size_t num_items);
#endif // TRANSFORMER_ENGINE_PERMUTATION_H_
@@ -96,6 +96,17 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
                                  cudaStream_t stream);
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
@@ -23,14 +23,15 @@ extern "C" {
 */
enum NVTEDType {
  kNVTEByte = 0,       /*!< Byte */
  kNVTEInt16 = 1,      /*!< 16-bit integer */
  kNVTEInt32 = 2,      /*!< 32-bit integer */
  kNVTEInt64 = 3,      /*!< 64-bit integer */
  kNVTEFloat32 = 4,    /*!< 32-bit float */
  kNVTEFloat16 = 5,    /*!< 16-bit float (E5M10) */
  kNVTEBFloat16 = 6,   /*!< 16-bit bfloat (E8M7) */
  kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
  kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
  kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
  kNVTENumTypes        /*!< Number of supported types */
};
@@ -38,12 +39,10 @@ enum NVTEDType {
 * \brief Shape of the tensor.
 */
struct NVTEShape {
  /*! \brief Shape data, with ndim valid elements. */
  size_t data[15];
  /*! \brief Number of dimensions. */
  size_t ndim;
};
/*! \struct NVTEBasicTensor
@@ -343,6 +342,23 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
 */
void nvte_destroy_quantization_config(NVTEQuantizationConfig config);
/*! \brief Check whether non-TN FP8 GEMM is supported.
 *
 * \return A flag indicating whether non-TN FP8 GEMM is supported.
 */
int nvte_is_non_tn_fp8_gemm_supported();
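A possible call site (hypothetical, for illustration):
// Fall back to the TN layout when non-TN FP8 GEMM is unavailable.
if (!nvte_is_non_tn_fp8_gemm_supported()) {
  // transpose operands into the TN layout before the GEMM
}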
/*! \brief Performs a memset of the data at the given pointer and size in bytes.
*
* \param[in] ptr Pointer to the memory to be set.
* \param[in] value Value to set the memory to.
* \param[in] size_in_bytes Size of the memory in bytes.
* \param[in] stream CUDA stream to use for the operation.
*
* This function calls a fill kernel for small sizes and calls cudaMemsetAsync for larger sizes.
*/
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
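A minimal usage sketch (pointer name is hypothetical):
// Zero a freshly allocated 4-byte counter on `stream`; small sizes take the
// fill-kernel path described above.
void *d_counter = nullptr;
cudaMalloc(&d_counter, sizeof(uint32_t));
nvte_memset(d_counter, 0, sizeof(uint32_t), stream);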
#ifdef __cplusplus
} // extern "C"
@@ -358,14 +374,15 @@ namespace transformer_engine {
 */
enum class DType {
  kByte = 0,
  kInt16 = 1,
  kInt32 = 2,
  kInt64 = 3,
  kFloat32 = 4,
  kFloat16 = 5,
  kBFloat16 = 6,
  kFloat8E4M3 = 7,
  kFloat8E5M2 = 8,
  kFloat8E8M0 = 9,
  kNumTypes
};
@@ -691,15 +708,10 @@ class TensorWrapper {
  static constexpr size_t defaultData = 1;
  static constexpr NVTEShape defaultShape = {
      {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
 private:
  NVTEShape convertShape(const NVTEShape &s) { return s; }
  NVTEShape convertShape(const std::vector<size_t> &s) {
    return nvte_make_shape(s.data(), s.size());
@@ -4,23 +4,16 @@
 * See LICENSE for license information.
 ************************************************************************/
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
namespace transformer_engine {
namespace multi_tensor_adam {
#define BLOCK_SIZE 512
#define ILP 4
@@ -39,7 +32,6 @@ using fp8e5m2 = __nv_fp8_e5m2;
using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2;
#endif
template <typename T>
struct is_fp8 : std::false_type {};
@@ -585,12 +577,13 @@ struct AdamCapturableMasterFunctor {
  }
};
void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
                            std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                            const float beta1, const float beta2, const float epsilon,
                            const int step, const int mode, const int bias_correction,
                            const float weight_decay, const int device_id, cudaStream_t stream) {
  const size_t num_tensor_lists = tensor_lists.size();
  const size_t num_tensors_per_list = tensor_lists[0].size();
  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -601,10 +594,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
  size_t max_size = 0;
  bool requires_64bit_indexing = false;
  for (size_t i = 0; i < num_tensor_lists; i++) {
    for (size_t j = 0; j < num_tensors_per_list; j++) {
      if (tensor_lists[i][j]->numel() > max_size) {
        max_size = tensor_lists[i][j]->numel();
        if (max_size >= INT_MAX) {
          requires_64bit_indexing = true;
          break;
@@ -616,69 +609,70 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
    }
  }
  const auto g_in_type_te = tensor_lists[0][0]->dtype();
  const auto p_in_type_te = tensor_lists[1][0]->dtype();
  // case 4: g, p, m, v
  // case 5: g, p, m, v, p_master
  NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
  if (requires_64bit_indexing) {
    if (num_tensor_lists == 4) {
      // Assume single type across p,g,m1,m2 now
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          p_in_type_te, p_in_type,
          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
              g_in_type_te, g_in_type,
              multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag, tensor_lists,
                                                AdamFunctor<p_in_type, g_in_type, float, int64_t>(),
                                                device_id, stream, beta1, beta2, bias_correction1,
                                                bias_correction2, epsilon, lr, (adamMode_t)mode,
                                                weight_decay);));
    } else {
      // g, p, m, v, p_master
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          p_in_type_te, p_in_type,
          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
              g_in_type_te, g_in_type,
              multi_tensor_apply<BLOCK_SIZE, 5>(
                  (int64_t)chunk_size, noop_flag, tensor_lists,
                  AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
                  beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
                  weight_decay);));
    }
  } else {
    if (num_tensor_lists == 4) {
      // Assume single type across p,g,m1,m2 now
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          p_in_type_te, p_in_type,
          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
              g_in_type_te, g_in_type,
              multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
                                                AdamFunctor<p_in_type, g_in_type, float, int32_t>(),
                                                device_id, stream, beta1, beta2, bias_correction1,
                                                bias_correction2, epsilon, lr, (adamMode_t)mode,
                                                weight_decay);));
    } else {
      TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
          p_in_type_te, p_in_type,
          TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
              g_in_type_te, g_in_type,
              multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
                                                AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
                                                device_id, stream, beta1, beta2, bias_correction1,
                                                bias_correction2, epsilon, lr, (adamMode_t)mode,
                                                weight_decay);));
    }
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
                                            std::vector<std::vector<Tensor *>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay,
                                            const int device_id, cudaStream_t stream) {
  const size_t num_tensor_lists = tensor_lists.size();
  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -687,34 +681,34 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
    bias_correction2 = 1 - std::pow(beta2, step);
  }
  const auto g_in_type_te = tensor_lists[0][0]->dtype();
  const auto p_in_type_te = tensor_lists[1][0]->dtype();
  // case 5: g, p, m, v, p_master
  NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
  NVTE_CHECK(p_in_type_te == DType::kBFloat16,
             "Adam with BF16 param remainders requires BF16 params");
  // g, p, m, v, p_master
  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      g_in_type_te, g_in_type,
      multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
                                        AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(),
                                        device_id, stream, beta1, beta2, bias_correction1,
                                        bias_correction2, epsilon, lr, (adamMode_t)mode,
                                        weight_decay););
  NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
                                std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, const DType fp8_dtype,
                                const int device_id, cudaStream_t stream) {
  const size_t num_tensor_lists = tensor_lists.size();
  const size_t num_tensors_per_list = tensor_lists[0].size();
  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
@@ -725,10 +719,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
  size_t max_size = 0;
  bool requires_64bit_indexing = false;
  for (size_t i = 0; i < num_tensor_lists; i++) {
    for (size_t j = 0; j < num_tensors_per_list; j++) {
      if (tensor_lists[i][j]->numel() > max_size) {
        max_size = tensor_lists[i][j]->numel();
        if (max_size >= INT_MAX) {
          requires_64bit_indexing = true;
          break;
@@ -740,66 +734,147 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
    }
  }
  const auto g_in_type_te = tensor_lists[0][0]->dtype();
  // case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
  NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
  if (requires_64bit_indexing) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        fp8_dtype, FP8_T,
        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
            g_in_type_te, g_in_type,
            multi_tensor_apply<BLOCK_SIZE, 5, true>(
                (int64_t)chunk_size, noop_flag, tensor_lists,
                AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
                beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
                weight_decay);));
  } else {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        fp8_dtype, FP8_T,
        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
            g_in_type_te, g_in_type,
            multi_tensor_apply<BLOCK_SIZE, 5, true>(chunk_size, noop_flag, tensor_lists,
                                                    AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
                                                    device_id, stream, beta1, beta2, bias_correction1,
                                                    bias_correction2, epsilon, lr, (adamMode_t)mode,
                                                    weight_decay);));
  }
  NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
                                       std::vector<std::vector<Tensor *>> tensor_lists, Tensor lr,
                                       const float beta1, const float beta2, const float epsilon,
                                       Tensor step, const int mode, const int bias_correction,
                                       const float weight_decay, Tensor inv_scale,
                                       const int device_id, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      tensor_lists[0][0]->dtype(), dtype,
      multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
                                        AdamCapturableFunctor<dtype, float>(), device_id, stream,
                                        beta1, beta2, reinterpret_cast<int *>(step.data.dptr),
                                        bias_correction, epsilon,
                                        reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
                                        weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
  NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
                                              std::vector<std::vector<Tensor *>> tensor_lists,
                                              Tensor lr, const float beta1, const float beta2,
                                              const float epsilon, Tensor step, const int mode,
                                              const int bias_correction, const float weight_decay,
                                              Tensor inv_scale, const int device_id,
                                              cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      tensor_lists[0][0]->dtype(), dtype,
      multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
                                        AdamCapturableMasterFunctor<dtype, float>(), device_id,
                                        stream, beta1, beta2,
                                        reinterpret_cast<int *>(step.data.dptr), bias_correction,
                                        epsilon, reinterpret_cast<float *>(lr.data.dptr),
                                        (adamMode_t)mode, weight_decay,
                                        reinterpret_cast<float *>(inv_scale.data.dptr));)
  NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_adam
} // namespace transformer_engine
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
}
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
@@ -4,23 +4,21 @@
 * See LICENSE for license information.
 ************************************************************************/
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <sstream>
#include "../recipe/recipe_common.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
namespace transformer_engine {
namespace multi_tensor_compute_scale {
#define BLOCK_SIZE 256
@@ -57,12 +55,29 @@ struct ComputeScaleAndScaleInvFunctor {
  }
};
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
                                                   std::vector<std::vector<Tensor *>> tensor_lists,
                                                   float max_fp8, bool force_pow_2_scales,
                                                   float epsilon, const int device_id,
                                                   cudaStream_t stream) {
  multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
                                    ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
                                    force_pow_2_scales, epsilon);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
} }
...@@ -4,18 +4,16 @@ ...@@ -4,18 +4,16 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h> #include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_l2norm {
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
...@@ -31,6 +29,96 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s ...@@ -31,6 +29,96 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*) ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*)
} }
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#ifdef __HIP_PLATFORM_AMD__
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i);
#else
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
  }

  // Make sure the smem result is visible to all warps, and avoid a potential
  // write-before-read race when reduce_block_into_lanes is called back-to-back.
  __syncthreads();

  return final;
}
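// Illustrative usage sketch, assuming a 256-thread launch; the kernel and
// buffer names are made up for the example. With lanes = 1 (the default),
// lane 0 of warp 0 ends up holding the block-wide sum.
__global__ void block_sum_example(const float *in, float *out) {
  __shared__ float smem[256];  // one slot per thread, as the reducer expects
  const float v = in[blockIdx.x * blockDim.x + threadIdx.x];
  const float total = reduce_block_into_lanes(smem, v);
  if (threadIdx.x == 0) out[blockIdx.x] = total;
}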
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename x_t> template <typename x_t>
struct L2NormFunctor { struct L2NormFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem, __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
...@@ -56,7 +144,7 @@ struct L2NormFunctor { ...@@ -56,7 +144,7 @@ struct L2NormFunctor {
x_t r_x[ILP]; x_t r_x[ILP];
for (int i = 0; i < ILP; i++) { for (int i = 0; i < ILP; i++) {
vals[i] = 0.f; vals[i] = 0.f;
r_x[i] = 0; r_x[i] = 0.f;
} }
// to make things simple, we put the aligned case in a different code path // to make things simple, we put the aligned case in a different code path
...@@ -126,7 +214,7 @@ struct UnscaleL2NormFunctor { ...@@ -126,7 +214,7 @@ struct UnscaleL2NormFunctor {
x_t r_x[ILP]; x_t r_x[ILP];
for (int i = 0; i < ILP; i++) { for (int i = 0; i < ILP; i++) {
vals[i] = 0.f; vals[i] = 0.f;
r_x[i] = 0; r_x[i] = 0.f;
} }
// to make things simple, we put the aligned case in a different code path // to make things simple, we put the aligned case in a different code path
...@@ -310,103 +398,96 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, ...@@ -310,103 +398,96 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
} }
} }
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
at::optional<bool> per_tensor_python) { Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; bool per_tensor, int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
auto output = at::zeros({320}, float_options); tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 1>(
at::Tensor output_per_tensor; chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
at::Tensor ret_per_tensor; stream, reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
int ntensors = tensor_lists[0].size(); max_chunks_per_tensor);)
int max_chunks_per_tensor = -1;
NVTE_CHECK_CUDA(cudaGetLastError());
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<BLOCK_SIZE, 1>(chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to end. // This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence // I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now // logic, but keeping it simple for now
auto ret = at::empty({1}, output.options()); const OptionalCUDAGuard device_guard(device_id);
const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
auto stream = at::cuda::getCurrentCUDAStream(); reinterpret_cast<float *>(output.data.dptr),
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>( per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, reinterpret_cast<float *>(ret.data.dptr),
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor, per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor); max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
} }
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda( void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<Tensor *>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) { Tensor output, Tensor output_per_tensor, Tensor ret,
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
int max_chunks_per_tensor, const int device_id,
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); cudaStream_t stream) {
auto output = at::zeros({320}, float_options); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
at::Tensor output_per_tensor; multi_tensor_apply<BLOCK_SIZE, 1>(
at::Tensor ret_per_tensor; chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(inv_scale.data.dptr),
int ntensors = tensor_lists[0].size(); reinterpret_cast<float *>(output.data.dptr),
int max_chunks_per_tensor = -1; per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
if (per_tensor) {
for (int t = 0; t < ntensors; t++) { NVTE_CHECK_CUDA(cudaGetLastError());
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
multi_tensor_apply<BLOCK_SIZE, 1>(chunk_size, noop_flag, tensor_lists,
UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(),
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// This involves one more small kernel launch, but it will be negligible end to end. // This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence // I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now // logic, but keeping it simple for now
auto ret = at::empty({1}, output.options()); const OptionalCUDAGuard device_guard(device_id);
const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
auto stream = at::cuda::getCurrentCUDAStream(); reinterpret_cast<float *>(output.data.dptr),
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>( per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, reinterpret_cast<float *>(ret.data.dptr),
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor, per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor); max_chunks_per_tensor);
}
} // namespace multi_tensor_l2norm
} // namespace transformer_engine
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor); void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
*reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
} }
...@@ -5,17 +5,62 @@ ...@@ -5,17 +5,62 @@
************************************************************************/ ************************************************************************/
#pragma once #pragma once
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h> #include <assert.h>
#include <c10/cuda/CUDAGuard.h> #include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "common/common.h" #include "../common.h"
// This header is the one-stop shop for all your multi-tensor apply needs. // This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
class OptionalCUDAGuard {
public:
explicit OptionalCUDAGuard(int new_device) {
if (new_device < 0) return;
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
if (new_device != current_device) {
NVTE_CHECK_CUDA(cudaSetDevice(new_device));
device_changed_ = true;
prev_device_ = current_device;
}
}
OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
: prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
other.device_changed_ = false;
}
OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
if (this != &other) {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
prev_device_ = other.prev_device_;
device_changed_ = other.device_changed_;
other.device_changed_ = false;
}
return *this;
}
~OptionalCUDAGuard() {
if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
}
}
private:
int prev_device_;
bool device_changed_ = false;
};
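// Minimal usage sketch of the guard (the function name is illustrative): a
// negative device id makes the guard a no-op; otherwise the previous device
// is restored when the guard goes out of scope.
inline void with_device_example(int device_id) {
  const OptionalCUDAGuard guard(device_id);  // no-op if device_id < 0
  // ... launch kernels on device_id here ...
}  // previous device restored here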
// TODO: Kernel arg size limit may be <4KB for some other cards (i.e., Jetson) // TODO: Kernel arg size limit may be <4KB for some other cards (i.e., Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
...@@ -46,62 +91,40 @@ __global__ void __launch_bounds__(block_size) multi_tensor_apply_kernel(int64_t ...@@ -46,62 +91,40 @@ __global__ void __launch_bounds__(block_size) multi_tensor_apply_kernel(int64_t
} }
template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typename... ArgTypes> template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag, void multi_tensor_apply(int64_t chunk_size,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable, const transformer_engine::Tensor &noop_flag,
ArgTypes... args) { std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
if constexpr (USE_FP8) { T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
TORCH_CHECK(tensor_lists.size() == depth + 3, const size_t num_tensor_lists = tensor_lists.size();
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, " const size_t num_tensors_per_list = tensor_lists[0].size();
"amax, scale_inv) for fp8");
} else {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
}
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < depth; l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory =
(contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
if constexpr (USE_FP8) { if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0, NVTE_CHECK(num_tensor_lists == depth + 3,
"Size mismatch among tensor lists"); "tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
NVTE_CHECK(num_tensor_lists == depth, "tensor_lists.size() != depth");
} }
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth, USE_FP8> tl; TensorListMetadata<depth, USE_FP8> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); const OptionalCUDAGuard device_guard(device_id);
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0; tl.start_tensor_this_launch = 0;
int loc_block_info = 0; int loc_block_info = 0;
int loc_tensor_info = 0; int loc_tensor_info = 0;
auto kernel = &multi_tensor_apply_kernel<block_size, TensorListMetadata<depth, USE_FP8>, T, ArgTypes...>; auto kernel = &multi_tensor_apply_kernel<block_size, TensorListMetadata<depth, USE_FP8>, T, ArgTypes...>;
for (int t = 0; t < ntensors; t++) { for (int t = 0; t < num_tensors_per_list; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); tl.sizes[loc_tensor_info] = tensor_lists[0][t]->numel();
for (int d = 0; d < depth; d++) for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr;
if constexpr (USE_FP8) { if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++) for (int i = 0; i < 3; i++)
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr(); tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t]->data.dptr;
} }
loc_tensor_info++; loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; auto chunks_this_tensor = (tensor_lists[0][t]->numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) { for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
...@@ -111,12 +134,12 @@ void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag, ...@@ -111,12 +134,12 @@ void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag,
bool tensors_full = bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1); (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); bool last_chunk = (t == num_tensors_per_list - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) { if (tensors_full || blocks_full || last_chunk) {
kernel<<<loc_block_info, block_size, 0, stream>>>( kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...); chunk_size, reinterpret_cast<int *>(noop_flag.data.dptr), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt. // Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0; loc_block_info = 0;
......
...@@ -4,19 +4,20 @@ ...@@ -4,19 +4,20 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h> #include <assert.h>
#include <cuda_fp8.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype. // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <sstream> #include <sstream>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_scale {
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
...@@ -66,7 +67,7 @@ struct ScaleFunctor { ...@@ -66,7 +67,7 @@ struct ScaleFunctor {
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale; r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]); finite = finite && isfinite(static_cast<float>(r_in[ii]));
} }
// store // store
load_store(out, r_out, i_start, 0); load_store(out, r_out, i_start, 0);
...@@ -76,7 +77,7 @@ struct ScaleFunctor { ...@@ -76,7 +77,7 @@ struct ScaleFunctor {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0; r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x; int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i]; if (i < n && i < chunk_size) r_in[ii] = in[i];
} }
...@@ -88,7 +89,7 @@ struct ScaleFunctor { ...@@ -88,7 +89,7 @@ struct ScaleFunctor {
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale; r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]); finite = finite && isfinite(static_cast<float>(r_in[ii]));
} }
#pragma unroll #pragma unroll
for (int ii = 0; ii < ILP; ii++) { for (int ii = 0; ii < ILP; ii++) {
...@@ -101,20 +102,29 @@ struct ScaleFunctor { ...@@ -101,20 +102,29 @@ struct ScaleFunctor {
} }
}; };
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) { std::vector<std::vector<Tensor *>> tensor_lists, float scale,
using namespace at; const int device_id, cudaStream_t stream) {
// The output (downscaled) type is always float. TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
// If build times suffer, think about where to put this dispatch, tensor_lists[0][0]->dtype(), p_in_type,
// and what logic should be moved out of multi_tensor_apply. TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);)) ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
AT_CUDA_CHECK(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_scale
} // namespace transformer_engine
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine;
// AT_CUDA_CHECK(cudaDeviceSynchronize()); multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
} }
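// Hedged caller-side sketch for the wrapper above. How the NVTETensor handles
// (noop_flag, inputs, outputs) are created is outside this diff, so the setup
// is schematic and the names are assumptions:
//   NVTETensor *tensor_lists[2] = {inputs, outputs};  // two lists of N tensors
//   nvte_multi_tensor_scale_cuda(/*chunk_size=*/65536, noop_flag, tensor_lists,
//                                /*num_tensor_lists=*/2, /*num_tensors_per_list=*/N,
//                                /*scale=*/1.f / loss_scale, /*device_id=*/0, stream);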
...@@ -4,14 +4,16 @@ ...@@ -4,14 +4,16 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h> #include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_sgd {
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
...@@ -54,9 +56,9 @@ struct SGDFunctor { ...@@ -54,9 +56,9 @@ struct SGDFunctor {
T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]); T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]);
mom_in += chunk_idx * chunk_size; mom_in += chunk_idx * chunk_size;
at::Half* model_weights_out = nullptr; fp16* model_weights_out = nullptr;
if (N == 4) { if (N == 4) {
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; model_weights_out = reinterpret_cast<fp16*>(tl.addresses[3][tensor_loc]);
model_weights_out += chunk_idx * chunk_size; model_weights_out += chunk_idx * chunk_size;
} }
...@@ -112,7 +114,7 @@ struct SGDFunctor { ...@@ -112,7 +114,7 @@ struct SGDFunctor {
weight_in[i] += (-lr * incoming_grads[ii]); weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights // if necessary, write out an fp16 copy of the weights
if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]); if (N == 4) model_weights_out[i] = static_cast<fp16>(weight_in[i]);
// also write out the new momentum // also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii]; if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
...@@ -122,23 +124,23 @@ struct SGDFunctor { ...@@ -122,23 +124,23 @@ struct SGDFunctor {
} }
}; };
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd, std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float momentum, float dampening, float lr, bool nesterov, bool first_run, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) { bool wd_after_momentum, float scale, const int device_id,
auto num_tensors = tensor_lists.size(); cudaStream_t stream) {
auto grad_type = tensor_lists[0][0].scalar_type(); const size_t num_tensor_lists = tensor_lists.size();
auto weight_type = tensor_lists[1][0].scalar_type(); const size_t num_tensors_per_list = tensor_lists[0].size();
if (num_tensors == 4) { auto grad_type = tensor_lists[0][0]->dtype();
for (int i = 0; i < tensor_lists[3].size(); i++) auto weight_type = tensor_lists[1][0]->dtype();
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16."); if (num_tensor_lists == 4) {
for (int i = 0; i < num_tensors_per_list; i++)
NVTE_CHECK(tensor_lists[3][i]->dtype() == DType::kFloat16,
"Additional output tensors should always be fp16.");
} }
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 4 possibilities to handle here, in terms of // We have 4 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy // grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No // 1. fp16, fp16, fp16, No
...@@ -150,53 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, ...@@ -150,53 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// we don't want the majority of them. // we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No // Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half && if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
num_tensors == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr, SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
nesterov, first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<BLOCK_SIZE, 3>(
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No // Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 3) { weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 3. fp16, fp32, fp32, Yes // Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && // NOLINT(*) else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} }
// Case 4. fp32, fp32, fp32, Yes // Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && // NOLINT(*) else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) { weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists, multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov, SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
first_run, wd_after_momentum, scale); lr, nesterov, first_run, wd_after_momentum, scale);
} else { } else {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", NVTE_ERROR("Unsupported combination of weight and gradient types.");
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
} }
AT_CUDA_CHECK(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_sgd
} // namespace transformer_engine
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor** tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
} }
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <transformer_engine/permutation.h> #include <transformer_engine/permutation.h>
#include <cub/cub.cuh>
#include "../common.h" #include "../common.h"
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
...@@ -385,3 +387,11 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id ...@@ -385,3 +387,11 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK, reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
num_cols, stream);); num_cols, stream););
} }
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out,
size_t num_items) {
NVTE_API_CALL(nvte_device_radix_sort_pairs);
cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out, values_in,
values_out, num_items);
}
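// Hedged usage sketch of the two-phase CUB convention this wrapper inherits:
// calling with temp_storage == nullptr only writes the required byte count;
// a second call with allocated storage performs the sort. Names are illustrative.
static void radix_sort_pairs_example(int *keys_in, int *keys_out, int *values_in,
                                     int *values_out, size_t num_items) {
  size_t temp_bytes = 0;
  nvte_device_radix_sort_pairs(nullptr, &temp_bytes, keys_in, keys_out, values_in,
                               values_out, num_items);  // size query only
  void *temp_storage = nullptr;
  cudaMalloc(&temp_storage, temp_bytes);
  nvte_device_radix_sort_pairs(temp_storage, &temp_bytes, keys_in, keys_out,
                               values_in, values_out, num_items);  // actual sort
  cudaFree(temp_storage);
}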
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace fp8_block_scaling_recipe {
constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_compute_partial_amax_kernel(const IType *input, float *amax_ptr,
const size_t amax_stride_h,
const size_t amax_stride_w, const size_t h,
const size_t w, const size_t start_offset,
const size_t len) {
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kLoopsPerCol = kTileDim / kNumWarps;
const int tile_col = blockIdx.x;
const int tile_row = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
__shared__ float smem[kNumWarps];
float amax = 0.0f;
for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp);
size_t idx = r * w + c;
if (r < h && c < w && idx >= start_offset && idx < end_offset) {
float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
if (threadIdx.x % kThreadsPerWarp == 0) {
smem[threadIdx.x / kThreadsPerWarp] = amax;
}
__syncthreads();
if (threadIdx.x == 0) {
for (int i = 0; i < kNumWarps; ++i) {
float other_amax = smem[i];
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
}
}
template <typename IType, typename OType, bool kWidthAligned>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_partial_cast_kernel(const IType *input, OType *output, const float *scale_ptr,
const size_t scale_stride_h, const size_t scale_stride_w,
const size_t h, const size_t w, const size_t start_offset,
const size_t len) {
using transformer_engine::Vec;
static_assert(sizeof(OType) == 1);
constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kRowsPerWarp = kTileDim / kNumWarps;
__shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank];
const int tile_w = blockIdx.x;
const int tile_h = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_minus_offset = output - start_offset;
const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];
// Load input data into shared memory
bool skip_store = true;
for (int i = 0; i < kRowsPerWarp; ++i) {
for (int j = 0; j < kLoopsPerRow; ++j) {
const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
const int h_in_input = tile_h * kTileDim + h_in_smem;
const int w_in_input = tile_w * kTileDim + w_in_smem;
const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
skip_store = false;
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
skip_store = skip_store && other_skip_store;
}
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
if (skip_store) {
return;
}
// Store the casted data into the output.
// Note that this store operation might write "out-of-bounds", but it is intentional:
// 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region
// from start_offset to end_offset), not the boundary of the entire output memory. Therefore,
// this out-of-bounds write will not cause illegal memory access.
// 2. We assume that the subsequent all-gather operation happens in-place, so any parts that
// should not be updated here will be overwritten by the all-gather.
// This tricky approach allows us to avoid checking whether each output index falls within
// [start, end), resulting in a significant performance improvement.
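  // Worked micro-example with made-up numbers: if the local shard covers
  // [start_offset, end_offset) = [16384, 32768) and a 128x128 tile straddles
  // index 32768, the whole tile is stored once any of its elements is live;
  // indices >= 32768 land in the neighboring shard's region of the same
  // allocation and are overwritten by the subsequent in-place all-gather.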
Vec<OType, kNumOutputElemsPerBank> vec_output;
for (int i = 0; i < kRowsPerWarp; ++i) {
const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
}
const int row_in_output = tile_h * kTileDim + row_in_smem;
const int col_in_output = tile_w * kTileDim + col_in_smem;
const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
if (row_in_output < h) {
if constexpr (kWidthAligned) {
vec_output.store_to(output_minus_offset + idx_in_output);
} else {
int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
}
}
}
}
void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
size_t len = inp.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr),
amax_stride_h, amax_stride_w, h, w, start_offset,
len);)
}
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
size_t w, size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len, const DType out_dtype,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
size_t len = inp.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned,
fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w,
h, w, start_offset, len);)))
}
} // namespace fp8_block_scaling_recipe
} // namespace transformer_engine
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w,
amax_stride_h, amax_stride_w, start_offset, block_len, stream);
}
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out),
*reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w, start_offset,
block_len, static_cast<DType>(out_dtype), stream);
}
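// Hedged sketch of the intended two-step flow; deriving per-tile scales between
// the calls is recipe-specific and not shown, and `tiles_w` (the number of
// 128-wide tile columns) is an assumption of the example:
//   nvte_fp8_block_scaling_compute_partial_amax(inp, amax, h, w,
//       /*amax_stride_h=*/tiles_w, /*amax_stride_w=*/1,
//       start_offset, /*block_len=*/128, stream);
//   // ... derive scale from amax (e.g. FP8_MAX / amax per 128x128 tile) ...
//   nvte_fp8_block_scaling_partial_cast(inp, out, scale, h, w,
//       /*scale_stride_h=*/tiles_w, /*scale_stride_w=*/1,
//       start_offset, /*block_len=*/128, kNVTEFloat8E4M3, stream);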
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <iostream> #include <iostream>
#include "common.h" #include "common.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -48,11 +49,11 @@ std::string to_string(const DType type) { ...@@ -48,11 +49,11 @@ std::string to_string(const DType type) {
std::string to_string(const NVTEScalingMode &mode) { std::string to_string(const NVTEScalingMode &mode) {
switch (mode) { switch (mode) {
case NVTE_DELAYED_TENSOR_SCALING: case NVTE_DELAYED_TENSOR_SCALING:
return "Delayed Tensor Scaling"; return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING: case NVTE_MXFP8_1D_SCALING:
return "MXFP8 1D Scaling"; return "NVTE_MXFP8_1D_SCALING";
case NVTE_INVALID_SCALING: case NVTE_INVALID_SCALING:
return "Invalid Scaling"; return "NVTE_INVALID_SCALING";
} }
return "Invalid Scaling"; return "Invalid Scaling";
} }
...@@ -214,15 +215,13 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) { ...@@ -214,15 +215,13 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTEShape ret; NVTEShape ret;
if (ndim == 0) { if (ndim == 0) {
ret.data = nullptr;
ret.ndim = 0; ret.ndim = 0;
return ret; return ret;
} }
NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
"Too many dims for NVTEShape (requested: ", ndim, "Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")"); ", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
std::copy(data, data + ndim, ret.owned_data); std::copy(data, data + ndim, ret.data);
ret.data = ret.owned_data;
ret.ndim = ndim; ret.ndim = ndim;
return ret; return ret;
} }
...@@ -350,7 +349,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, ...@@ -350,7 +349,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) { if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, {nullptr, 0}}; return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
} }
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor); const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (param_name) { switch (param_name) {
...@@ -483,3 +482,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { ...@@ -483,3 +482,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config); delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
} }
} }
int nvte_is_non_tn_fp8_gemm_supported() {
int deviceComputeCapability =
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
// Note: this is a temporary restriction and should be lifted in the future.
// (Remove this note once that is done.)
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
}
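// Illustrative caller-side check (variable name assumed):
//   const bool non_tn_ok = nvte_is_non_tn_fp8_gemm_supported() != 0;
//   // if !non_tn_ok, keep FP8 GEMM operands in a TN-compatible layout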
...@@ -134,9 +134,15 @@ bool supports_multicast(int device_id) { ...@@ -134,9 +134,15 @@ bool supports_multicast(int device_id) {
auto init = [&]() { auto init = [&]() {
CUdevice cudev; CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id); NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
int result; // Multicast support requires both the CUDA 12.1 UMD and KMD
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result, int result = 0;
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev); // Check if KMD >= 12.1
int driver_version;
NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
if (driver_version >= 12010) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
}
cache[device_id] = static_cast<bool>(result); cache[device_id] = static_cast<bool>(result);
}; };
std::call_once(flags[device_id], init); std::call_once(flags[device_id], init);
......
...@@ -23,10 +23,18 @@ ...@@ -23,10 +23,18 @@
#endif // __HIP_PLATFORM_AMD__ #endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h> #include <nvrtc.h>
#include <iostream>
#include <stdexcept> #include <stdexcept>
#include "../util/string.h" #include "../util/string.h"
#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
} while (false)
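// Illustrative use of the new macro (message text made up for the example):
//   NVTE_WARN("Multicast not supported on device ", device_id, "; falling back.");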
#define NVTE_ERROR(...) \ #define NVTE_ERROR(...) \
do { \ do { \
throw ::std::runtime_error(::transformer_engine::concat_strings( \ throw ::std::runtime_error(::transformer_engine::concat_strings( \
......