Commit f8c2af4c authored by yuguo


Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -562,5 +562,75 @@ size_t get_max_tokens(size_t num_tokens) {
return max_t;
}
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid > 0) return;
rng_state_dst[0] = seed[0];
rng_state_dst[1] = offset;
}
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= len) return;
if (cu_seqlen[tid] > 0) {
// atomicAdd only supports 32-bit dtypes
atomicAdd(out, 1);
}
}
void PopulateRngStateAsync(void *rng_state_dst, const void *seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream) {
size_t increment = 0;
if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
increment = 16;
} else {
constexpr int threads_per_cta = 128;
increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
}
auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
reinterpret_cast<const int64_t *>(seed), offset);
NVTE_CHECK_CUDA(cudaGetLastError());
}
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
// the workspace must provide at least 4 bytes
uint32_t *dout = static_cast<uint32_t *>(workspace);
uint32_t hout{};
cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
constexpr int threads = 128;
const int blocks = (len - 1) / threads + 1;
get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(static_cast<int32_t *>(cu_seqlen),
len, dout);
cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);
return hout;
}
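For orientation, a hedged host-side sketch of how this helper can be driven: allocate a 4-byte device workspace, upload a cu_seqlen buffer, and read back the count of non-zero entries. All buffer names and values below are illustrative, not part of this change.
// Hedged usage sketch (illustrative only): count the non-zero cu_seqlen entries.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Declared in the fused-attn header inside transformer_engine::fused_attn.
namespace transformer_engine {
namespace fused_attn {
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
}  // namespace fused_attn
}  // namespace transformer_engine

int main() {
  const size_t len = 8;  // e.g. batch_size x sequence_length entries to scan
  int32_t h_cu_seqlen[8] = {0, 3, 3, 7, 0, 0, 0, 0};
  int32_t *d_cu_seqlen = nullptr;
  uint32_t *d_workspace = nullptr;  // the helper needs exactly 4 bytes of workspace
  cudaMalloc(&d_cu_seqlen, sizeof(h_cu_seqlen));
  cudaMalloc(&d_workspace, sizeof(uint32_t));
  cudaMemcpy(d_cu_seqlen, h_cu_seqlen, sizeof(h_cu_seqlen), cudaMemcpyHostToDevice);
  uint32_t num_segments =
      transformer_engine::fused_attn::GetRuntimeNumSegments(d_cu_seqlen, d_workspace, len, 0);
  printf("segments: %u\n", num_segments);  // 3 for the data above (entries 3, 3, 7)
  cudaFree(d_cu_seqlen);
  cudaFree(d_workspace);
  return 0;
}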
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
if (captured) {
rng_state_ptr[0] = *seed_ptr;
rng_state_ptr[1] = static_cast<int64_t>(*offset_ptr + static_cast<int64_t>(offset_intragraph));
} else {
rng_state_ptr[0] = static_cast<int64_t>(seed_val);
rng_state_ptr[1] = static_cast<int64_t>(offset_val);
}
}
} // namespace fused_attn
} // namespace transformer_engine
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream) {
NVTE_API_CALL(nvte_extract_seed_and_offset);
using namespace transformer_engine;
fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>(
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph);
}
......@@ -150,6 +150,38 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
class FusedAttnOffsetManager {
public:
static FusedAttnOffsetManager &Instance() {
static thread_local FusedAttnOffsetManager instance;
return instance;
}
size_t GetAndUpdateOffset(size_t increment) {
size_t ret = offset_;
offset_ += increment;
return ret;
}
FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
void operator=(FusedAttnOffsetManager const &) = delete;
private:
FusedAttnOffsetManager() {}
size_t offset_ = 0;
};
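As a quick illustration (increments chosen arbitrarily), successive calls return the running offset and then advance it; the counter is per thread because Instance() returns a thread_local singleton.
// Hedged usage sketch; the increment values are illustrative only.
inline size_t offset_manager_demo() {
  size_t first = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(16);   // returns 0
  size_t second = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(16);  // returns 16
  size_t third = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(32);   // returns 32
  return first + second + third;  // a fourth call would return 64
}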
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
int64_t offset);
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
} // namespace fused_attn
} // namespace transformer_engine
......
......@@ -115,10 +115,10 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b,
const float *freqs, const int *start_positions,
scalar_t *dst, const bool interleaved, const int cp_size,
const int cp_rank, const int s, const int h, const int d,
const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
......@@ -149,7 +149,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
}
fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
......@@ -199,11 +200,12 @@ __global__ void fused_rope_backward_kernel(
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
const int *start_positions, scalar_t *output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
......@@ -223,8 +225,9 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
const int o_stride_d = 1;
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -262,15 +265,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
}
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b, const int h,
const int d, const int d2, const int stride_s_or_t, const int stride_b,
const Tensor &start_positions, Tensor *output,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
const int cp_rank, const int s, const int b, const int h, const int d,
const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream););
......@@ -295,19 +300,19 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
fused_rope_forward(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
......
......@@ -11,8 +11,7 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_
#include <cstdint>
#include "stdint.h"
#include "transformer_engine.h"
#ifdef __cplusplus
......@@ -245,7 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack* Aux_CTX_Tensors,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
......@@ -301,7 +300,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
*/
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQKV,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
......@@ -369,7 +368,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
......@@ -430,7 +429,7 @@ void nvte_fused_attn_fwd_kvpacked(
*/
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
......@@ -501,7 +500,7 @@ void nvte_fused_attn_bwd_kvpacked(
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
......@@ -570,7 +569,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
*/
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack* Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
......@@ -580,6 +579,76 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
* \param[in] rng_state_dst RNG state to store seed and offset.
* \param[in] seed Seed for RNG state.
* \param[in] q_max_seqlen Max sequence length used in the computation for Q;
* it may be >= max(seqlen_q_i) for i = 0, ..., batch_size - 1.
* \param[in] kv_max_seqlen Max sequence length used in the computation for K and V;
* it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size - 1.
* \param[in] backend Fused attention backend.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed,
size_t q_max_seqlen, size_t kv_max_seqlen,
NVTE_Fused_Attn_Backend backend, cudaStream_t stream);
/*! \brief Get the number of sequence segments at runtime.
*
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] workspace Workspace tensor.
* \param[in] len batch_size x sequence_length.
* \param[in] stream CUDA stream used for this operation.
*/
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
cudaStream_t stream);
void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph, cudaStream_t stream);
void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache,
NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens,
NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b,
int max_ctx_len, int max_seq_len, int max_pages_per_seq,
int is_non_paged, cudaStream_t stream);
void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens,
NVTETensor half, int half_idx, cudaStream_t stream);
void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens,
NVTETensor half_lse, int lse_packed,
int second_half_lse_seqlen, cudaStream_t stream);
void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step,
const NVTETensor &lse, const NVTETensor &lse_per_step,
const NVTETensor &cu_seqlens, int only_second_half, int lse_packed,
cudaStream_t stream);
void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step,
const NVTETensor &cu_seqlens, const char *first_half,
const char *second_half, cudaStream_t stream);
void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output,
int total_tokens, int world_size, int rank,
cudaStream_t stream);
void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int b, int max_seq_len, cudaStream_t stream);
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
int t, cudaStream_t stream);
void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream);
void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -20,6 +20,7 @@ extern "C" {
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets (one per sequence) for applying RoPE embeddings.
* \param[out] output Output tensor.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
......@@ -37,12 +38,12 @@ extern "C" {
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream);
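For clarity, the effect of the new start_positions argument in the non-context-parallel path mirrors the kernel change earlier in this diff: a per-batch offset is added to the sequence index before looking up the rotary frequencies. A minimal sketch (function name illustrative):
// Hedged sketch of the rotary position index once start_positions is supplied.
inline int rotary_position(int s_id, int b_id, const int *start_positions) {
  // A null start_positions keeps the previous behaviour: positions start at 0.
  const int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
  return s_id + begin_offset;
}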
/*! \brief Compute the backward of the fused rope.
*
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file multi_tensor.h
* \brief Functions handling multi tensor kernels.
*/
#ifndef TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#define TRANSFORMER_ENGINE_MULTI_TENSOR_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream);
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream);
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream);
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_MULTI_TENSOR_H_
......@@ -18,4 +18,7 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
const NVTETensor prob, const int num_rows, const int topK, const int num_cols,
cudaStream_t stream = nullptr);
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out, size_t num_items);
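The temp_storage/temp_storage_bytes pair suggests the usual two-phase device-sort protocol (size query with a null buffer, then the actual sort); the sketch below assumes that protocol and uses hypothetical device pointers, so treat it as an assumption rather than documented behaviour.
// Hedged usage sketch, assuming the conventional two-phase pattern.
#include <cuda_runtime.h>
void sort_pairs_sketch(int *d_keys_in, int *d_keys_out, int *d_vals_in, int *d_vals_out,
                       size_t num_items) {
  size_t temp_bytes = 0;
  // First call with a null temp buffer only reports the required size (assumed).
  nvte_device_radix_sort_pairs(nullptr, &temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                               num_items);
  void *d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);
  // Second call performs the sort.
  nvte_device_radix_sort_pairs(d_temp, &temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out,
                               num_items);
  cudaFree(d_temp);
}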
#endif // TRANSFORMER_ENGINE_PERMUTATION_H_
......@@ -96,6 +96,17 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
cudaStream_t stream);
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -23,14 +23,15 @@ extern "C" {
*/
enum NVTEDType {
kNVTEByte = 0, /*!< Byte */
kNVTEInt32 = 1, /*!< 32-bit integer */
kNVTEInt64 = 2, /*!< 64-bit integer */
kNVTEFloat32 = 3, /*!< 32-bit float */
kNVTEFloat16 = 4, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */
kNVTEInt16 = 1, /*!< 16-bit integer */
kNVTEInt32 = 2, /*!< 32-bit integer */
kNVTEInt64 = 3, /*!< 64-bit integer */
kNVTEFloat32 = 4, /*!< 32-bit float */
kNVTEFloat16 = 5, /*!< 16-bit float (E5M10) */
kNVTEBFloat16 = 6, /*!< 16-bit bfloat (E8M7) */
kNVTEFloat8E4M3 = 7, /*!< 8-bit float (E4M3) */
kNVTEFloat8E5M2 = 8, /*!< 8-bit float (E5M2) */
kNVTEFloat8E8M0 = 9, /*!< 8-bit float (E8M0) */
kNVTENumTypes /*!< Number of supported types */
};
......@@ -38,12 +39,10 @@ enum NVTEDType {
* \brief Shape of the tensor.
*/
struct NVTEShape {
/*! \brief Shape data, of size ndim. */
const size_t *data;
/*! \brief Shape data, with ndim valid elements. */
size_t data[15];
/*! \brief Number of dimensions. */
size_t ndim;
/*! \brief Copy of data. Num dims limited to permit fixed struct size.*/
size_t owned_data[14];
};
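Since the shape data is now stored inline (up to 15 dimensions), a caller can fill the struct directly; nvte_make_shape, used later in this diff, serves the same purpose. A minimal sketch:
// Hedged sketch: build a 2-D NVTEShape with the new fixed-size layout.
inline NVTEShape make_shape_2d(size_t rows, size_t cols) {
  NVTEShape s{};  // zero-initializes the inline data array
  s.data[0] = rows;
  s.data[1] = cols;
  s.ndim = 2;
  return s;
}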
/*! \struct NVTEBasicTensor
......@@ -343,6 +342,23 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
*/
void nvte_destroy_quantization_config(NVTEQuantizationConfig config);
/*! \brief Check if non-TN FP8 Gemm is supported.
*
* \return A flag which indicates whether non-TN FP8 Gemm is supported or not.
*/
int nvte_is_non_tn_fp8_gemm_supported();
/*! \brief Performs a memset of the data at the given pointer and size in bytes.
*
* \param[in] ptr Pointer to the memory to be set.
* \param[in] value Value to set the memory to.
* \param[in] size_in_bytes Size of the memory in bytes.
* \param[in] stream CUDA stream to use for the operation.
*
* This function calls a fill kernel for small sizes and calls cudaMemsetAsync for larger sizes.
*/
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
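The brief above describes a size-based dispatch. A hedged sketch of that idea follows; the 4 KiB threshold, kernel, and function name are illustrative assumptions, not the library's implementation.
// Hedged sketch (not the library code): fill kernel for small buffers, cudaMemsetAsync otherwise.
__global__ void fill_bytes_kernel(unsigned char *ptr, unsigned char value, size_t n) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) ptr[i] = value;
}

void memset_sketch(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
  if (size_in_bytes == 0) return;
  constexpr size_t kSmallSize = 4096;  // assumed threshold, for illustration only
  if (size_in_bytes <= kSmallSize) {
    const int threads = 256;
    const int blocks = static_cast<int>((size_in_bytes + threads - 1) / threads);
    fill_bytes_kernel<<<blocks, threads, 0, stream>>>(
        static_cast<unsigned char *>(ptr), static_cast<unsigned char>(value), size_in_bytes);
  } else {
    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
  }
}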
#ifdef __cplusplus
} // extern "C"
......@@ -358,14 +374,15 @@ namespace transformer_engine {
*/
enum class DType {
kByte = 0,
kInt32 = 1,
kInt64 = 2,
kFloat32 = 3,
kFloat16 = 4,
kBFloat16 = 5,
kFloat8E4M3 = 6,
kFloat8E5M2 = 7,
kFloat8E8M0 = 8,
kInt16 = 1,
kInt32 = 2,
kInt64 = 3,
kFloat32 = 4,
kFloat16 = 5,
kBFloat16 = 6,
kFloat8E4M3 = 7,
kFloat8E5M2 = 8,
kFloat8E8M0 = 9,
kNumTypes
};
......@@ -691,15 +708,10 @@ class TensorWrapper {
static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = {
&defaultData, 1, {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
{defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
private:
NVTEShape convertShape(const NVTEShape &s) {
NVTEShape ret = s;
// Move the ownership rather than pointing to the parent shape.
ret.data = ret.owned_data;
return ret;
}
NVTEShape convertShape(const NVTEShape &s) { return s; }
NVTEShape convertShape(const std::vector<size_t> &s) {
return nvte_make_shape(s.data(), s.size());
......
......@@ -4,23 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#ifdef __HIP_PLATFORM_AMD__
#include "amd_detail/hip_float8.h"
#else
#include <cuda_fp8.h>
#endif
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "common/utils.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_adam {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -39,7 +32,6 @@ using fp8e5m2 = __nv_fp8_e5m2;
using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2;
#endif
using transformer_engine::DType;
template <typename T>
struct is_fp8 : std::false_type {};
......@@ -585,12 +577,13 @@ struct AdamCapturableMasterFunctor {
}
};
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay) {
using namespace at;
const float weight_decay, const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
......@@ -601,10 +594,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) {
max_size = tensor_lists[i][j]->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
......@@ -616,69 +609,70 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
// case 4: g, p, m, v
// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 4 || tl_size == 5, "tensor list must contain 4 or 5");
NVTE_CHECK(num_tensor_lists == 4 || num_tensor_lists == 5, "tensor list must contain 4 or 5");
if (requires_64bit_indexing) {
if (tl_size == 4) {
if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 4>((int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int64_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag,
tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>(
(int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<p_in_type, g_in_type, float, int64_t>(), device_id, stream,
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
} else {
if (tl_size == 4) {
if (num_tensor_lists == 4) {
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
AdamFunctor<scalar_t_0, scalar_t_1, float, int32_t>(), beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctor<p_in_type, g_in_type, float, int32_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2,
epsilon, lr, (adamMode_t)mode, weight_decay);));
} else {
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
p_in_type_te, p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<scalar_t_0, scalar_t_1, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AdamFunctorMaster<p_in_type, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
}
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay) {
using namespace at;
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
......@@ -687,34 +681,34 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag
bias_correction2 = 1 - std::pow(beta2, step);
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
const auto p_in_type = tensor_lists[1][0].scalar_type();
auto tl_size = tensor_lists.size();
const auto g_in_type_te = tensor_lists[0][0]->dtype();
const auto p_in_type_te = tensor_lists[1][0]->dtype();
// case 5: g, p, m, v, p_master
TORCH_CHECK(tl_size == 5, "tensor list must contain 5");
TORCH_CHECK(p_in_type == at::ScalarType::BFloat16,
"Adam with BF16 param remainders requires BF16 params");
NVTE_CHECK(num_tensor_lists == 5, "tensor list must contain 5");
NVTE_CHECK(p_in_type_te == DType::kBFloat16,
"Adam with BF16 param remainders requires BF16 params");
// g, p, m, v, p_master
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
p_in_type, 0, "adam",
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 1, "adam",
multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<scalar_t_1, float, int64_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay);));
AT_CUDA_CHECK(cudaGetLastError());
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5>((int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMasterParamRemainder<g_in_type, float, int64_t>(), device_id,
stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr,
(adamMode_t)mode, weight_decay););
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, DType fp8_dtype) {
using namespace at;
const float weight_decay, const DType fp8_dtype,
const int device_id, cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
......@@ -725,10 +719,10 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
size_t max_size = 0;
bool requires_64bit_indexing = false;
for (auto it = tensor_lists.begin(); it != tensor_lists.end(); it++) {
for (auto it2 = it->begin(); it2 != it->end(); it2++) {
if (it2->numel() > max_size) {
max_size = it2->numel();
for (size_t i = 0; i < num_tensor_lists; i++) {
for (size_t j = 0; j < num_tensors_per_list; j++) {
if (tensor_lists[i][j]->numel() > max_size) {
max_size = tensor_lists[i][j]->numel();
if (max_size >= INT_MAX) {
requires_64bit_indexing = true;
break;
......@@ -740,66 +734,147 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
}
}
const auto g_in_type = tensor_lists[0][0].scalar_type();
auto tl_size = tensor_lists.size();
const auto g_in_type_te = tensor_lists[0][0]->dtype();
// case 8: g, p_fp8, m, v, p_master, scale, amax, scale_inv
TORCH_CHECK(tl_size == 8, "tensor list must contain 8 tensors");
NVTE_CHECK(num_tensor_lists == 8, "tensor list must contain 8 tensors");
if (requires_64bit_indexing) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5, true>(
(int64_t)chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int64_t>(), beta1, beta2,
bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);));
AdamFunctorMaster<FP8_T, g_in_type, float, int64_t>(), device_id, stream, beta1,
beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
} else {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
fp8_dtype, FP8_T,
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
g_in_type, 0, "adam",
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
g_in_type_te, g_in_type,
multi_tensor_apply<BLOCK_SIZE, 5, true>(chunk_size, noop_flag, tensor_lists,
AdamFunctorMaster<FP8_T, scalar_t_0, float, int32_t>(),
beta1, beta2, bias_correction1, bias_correction2, epsilon,
lr, (adamMode_t)mode, weight_decay);));
AdamFunctorMaster<FP8_T, g_in_type, float, int32_t>(),
device_id, stream, beta1, beta2, bias_correction1,
bias_correction2, epsilon, lr, (adamMode_t)mode,
weight_decay);));
}
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor lr,
const float beta1, const float beta2, const float epsilon,
Tensor step, const int mode, const int bias_correction,
const float weight_decay, Tensor inv_scale,
const int device_id, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
AdamCapturableFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
AdamCapturableFunctor<dtype, float>(), device_id, stream, beta1, beta2,
reinterpret_cast<int *>(step.data.dptr), bias_correction, epsilon,
reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode, weight_decay,
reinterpret_cast<float *>(inv_scale.data.dptr));)
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor lr, const float beta1, const float beta2,
const float epsilon, at::Tensor step, const int mode,
void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
Tensor lr, const float beta1, const float beta2,
const float epsilon, Tensor step, const int mode,
const int bias_correction, const float weight_decay,
at::Tensor inv_scale) {
using namespace at;
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",
Tensor inv_scale, const int device_id,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 5>(chunk_size, noop_flag, tensor_lists,
AdamCapturableMasterFunctor<scalar_t_0, float>(), beta1, beta2,
step.data_ptr<int>(), bias_correction, epsilon, lr.data_ptr<float>(),
(adamMode_t)mode, weight_decay, inv_scale.data_ptr<float>());)
AdamCapturableMasterFunctor<dtype, float>(), device_id, stream, beta1,
beta2, reinterpret_cast<int *>(step.data.dptr), bias_correction,
epsilon, reinterpret_cast<float *>(lr.data.dptr), (adamMode_t)mode,
weight_decay, reinterpret_cast<float *>(inv_scale.data.dptr));)
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_adam
} // namespace transformer_engine
void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode,
const int bias_correction, const float weight_decay,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_param_remainder_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2,
const float epsilon, const int step, const int mode, const int bias_correction,
const float weight_decay, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_param_remainder_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, device_id, stream);
}
void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, const float lr,
const float beta1, const float beta2, const float epsilon,
const int step, const int mode, const int bias_correction,
const float weight_decay, const NVTEDType fp8_dtype,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_fp8_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2,
epsilon, step, mode, bias_correction, weight_decay, static_cast<DType>(fp8_dtype), device_id,
stream);
}
void nvte_multi_tensor_adam_capturable_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
AT_CUDA_CHECK(cudaGetLastError());
void nvte_multi_tensor_adam_capturable_master_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2,
const float epsilon, NVTETensor step, const int mode, const int bias_correction,
const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda);
using namespace transformer_engine;
multi_tensor_adam::multi_tensor_adam_capturable_master_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(lr), beta1, beta2, epsilon, *reinterpret_cast<Tensor *>(step),
mode, bias_correction, weight_decay, *reinterpret_cast<Tensor *>(inv_scale), device_id,
stream);
}
......@@ -4,23 +4,21 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <sstream>
#include "common/recipe/recipe_common.cuh"
#include "common/utils.cuh"
#include "../recipe/recipe_common.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_compute_scale {
#define BLOCK_SIZE 256
......@@ -57,12 +55,29 @@ struct ComputeScaleAndScaleInvFunctor {
}
};
void multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
float max_fp8, bool force_pow_2_scales, float epsilon) {
using namespace at;
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales,
float epsilon, const int device_id,
cudaStream_t stream) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);
AT_CUDA_CHECK(cudaGetLastError());
ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8,
force_pow_2_scales, epsilon);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, device_id, stream);
}
......@@ -4,18 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_l2norm {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -31,6 +29,96 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, int s
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; // NOLINT(*)
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#ifdef __HIP_PLATFORM_AMD__
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i);
#else
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
}
__syncthreads();
// Avoid potential write before read race when reduce_block_into_lanes is called back to back
return final;
}
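For reference, a hedged sketch of how this reduction is typically invoked from a kernel (launched with 256 threads per block and blockDim.y == 1; buffer names are illustrative):
// Hedged usage sketch: each thread contributes one value, thread 0 writes the block sum.
__global__ void block_sum_example(const float *in, float *out) {
  __shared__ float smem[256];  // one slot per thread in the block
  const float val = in[blockIdx.x * blockDim.x + threadIdx.x];
  const float total = reduce_block_into_lanes(smem, val);  // lanes = 1, share_result = false
  if (threadIdx.x == 0 && threadIdx.y == 0) out[blockIdx.x] = total;
}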
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename x_t>
struct L2NormFunctor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
......@@ -56,7 +144,7 @@ struct L2NormFunctor {
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
r_x[i] = 0.f;
}
// to make things simple, we put aligned case in a different code path
......@@ -126,7 +214,7 @@ struct UnscaleL2NormFunctor {
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
r_x[i] = 0;
r_x[i] = 0.f;
}
// to make things simple, we put aligned case in a different code path
......@@ -310,103 +398,96 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
}
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
multi_tensor_apply<BLOCK_SIZE, 1>(chunk_size, noop_flag, tensor_lists,
L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, Tensor output,
Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor,
bool per_tensor, int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 1>(
chunk_size, noop_flag, tensor_lists, L2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
NVTE_CHECK_CUDA(cudaGetLastError());
// This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
reinterpret_cast<float *>(ret.data.dptr),
per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;
auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
auto output = at::zeros({320}, float_options);
at::Tensor output_per_tensor;
at::Tensor ret_per_tensor;
int ntensors = tensor_lists[0].size();
int max_chunks_per_tensor = -1;
if (per_tensor) {
for (int t = 0; t < ntensors; t++) {
int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
if (max_chunks_this_tensor > max_chunks_per_tensor)
max_chunks_per_tensor = max_chunks_this_tensor;
}
output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
ret_per_tensor = at::empty({ntensors}, float_options);
} else {
ret_per_tensor = at::empty({0}, float_options);
}
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_unscale_l2norm_cuda",
multi_tensor_apply<BLOCK_SIZE, 1>(chunk_size, noop_flag, tensor_lists,
UnscaleL2NormFunctor<scalar_t_0>(), inv_scale.data_ptr<float>(),
output.data_ptr<float>(),
per_tensor ? output_per_tensor.data_ptr<float>() : nullptr, per_tensor,
max_chunks_per_tensor);)
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
Tensor output, Tensor output_per_tensor, Tensor ret,
Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), dtype,
multi_tensor_apply<BLOCK_SIZE, 1>(
chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor<dtype>(), device_id,
stream, reinterpret_cast<float *>(inv_scale.data.dptr),
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);)
NVTE_CHECK_CUDA(cudaGetLastError());
// This involves one more small kernel launch, but it will be negligible end to end.
// I could get rid of these by hacking the functor + multi tensor harness with persistence
// logic, but keeping it simple for now
auto ret = at::empty({1}, output.options());
const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
auto stream = at::cuda::getCurrentCUDAStream();
cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
output.data_ptr<float>(), per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
ret.data_ptr<float>(), per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
const OptionalCUDAGuard device_guard(device_id);
cleanup<<<per_tensor ? tensor_lists[0].size() : 1, 512, 0, stream>>>(
reinterpret_cast<float *>(output.data.dptr),
per_tensor ? reinterpret_cast<float *>(output_per_tensor.data.dptr) : nullptr,
reinterpret_cast<float *>(ret.data.dptr),
per_tensor ? reinterpret_cast<float *>(ret_per_tensor.data.dptr) : nullptr, per_tensor,
max_chunks_per_tensor);
}
} // namespace multi_tensor_l2norm
} // namespace transformer_engine
void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, int per_tensor,
int max_chunks_per_tensor, const int device_id,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor), per_tensor,
max_chunks_per_tensor, device_id, stream);
}
return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag,
NVTETensor **tensor_lists, const size_t num_tensor_lists,
const size_t num_tensors_per_list, NVTETensor output,
NVTETensor output_per_tensor, NVTETensor ret,
NVTETensor ret_per_tensor, NVTETensor inv_scale,
int per_tensor, int max_chunks_per_tensor,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda);
using namespace transformer_engine;
multi_tensor_l2norm::multi_tensor_unscale_l2norm_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
*reinterpret_cast<Tensor *>(output), *reinterpret_cast<Tensor *>(output_per_tensor),
*reinterpret_cast<Tensor *>(ret), *reinterpret_cast<Tensor *>(ret_per_tensor),
*reinterpret_cast<Tensor *>(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream);
}
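// Call-shape sketch (inferred from the depth-1 dispatch above, not guaranteed by
// this diff): tensor_lists is laid out as [num_tensor_lists][num_tensors_per_list]
// and both L2-norm entry points expect a single list of input tensors. `output`
// holds per-block partial sums (320 floats, matching depth_to_max_blocks),
// `output_per_tensor` holds num_tensors_per_list * max_chunks_per_tensor partials
// when per_tensor != 0, and the `cleanup` kernel reduces them into `ret` and
// `ret_per_tensor`.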
......@@ -5,17 +5,62 @@
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "common/common.h"
#include "../common.h"
// This header is the one-stop shop for all your multi-tensor apply needs.
// Change device if needed.
class OptionalCUDAGuard {
public:
explicit OptionalCUDAGuard(int new_device) {
if (new_device < 0) return;
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
if (new_device != current_device) {
NVTE_CHECK_CUDA(cudaSetDevice(new_device));
device_changed_ = true;
prev_device_ = current_device;
}
}
OptionalCUDAGuard(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete;
OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept
: prev_device_(other.prev_device_), device_changed_(other.device_changed_) {
other.device_changed_ = false;
}
OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept {
if (this != &other) {
if (device_changed_) {
cudaSetDevice(prev_device_);
}
prev_device_ = other.prev_device_;
device_changed_ = other.device_changed_;
other.device_changed_ = false;
}
return *this;
}
~OptionalCUDAGuard() {
if (device_changed_) {
NVTE_CHECK_CUDA(cudaSetDevice(prev_device_));
}
}
private:
int prev_device_;
bool device_changed_ = false;
};
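// Usage sketch (hypothetical call site): the guard acts as an RAII device switch;
// a negative device_id leaves the current device untouched.
inline void run_on_device(int device_id, cudaStream_t stream) {
  const OptionalCUDAGuard guard(device_id);  // switches the device if needed
  // ... launch kernels on `stream`; they now target `device_id` ...
}  // the destructor restores the previous device if it was changed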
// TODO: Kernel arg size limit may be <4KB for some other cards (e.g., Jetson)
constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24};
constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320};
......@@ -46,62 +91,40 @@ __global__ void __launch_bounds__(block_size) multi_tensor_apply_kernel(int64_t
}
template <int64_t block_size, int depth, bool USE_FP8 = false, typename T, typename... ArgTypes>
void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag,
const std::vector<std::vector<at::Tensor>> &tensor_lists, T callable,
ArgTypes... args) {
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists.size() == depth + 3,
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
}
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < depth; l++) { // No range-based for because I need indices
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
contiguous_memory =
(contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d));
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
void multi_tensor_apply(int64_t chunk_size,
const transformer_engine::Tensor &noop_flag,
std::vector<std::vector<transformer_engine::Tensor *>> tensor_lists,
T callable, const int device_id, cudaStream_t stream, ArgTypes... args) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
if constexpr (USE_FP8) {
TORCH_CHECK(tensor_lists[depth].size() == len0 && tensor_lists[depth + 1].size() == len0,
"Size mismatch among tensor lists");
NVTE_CHECK(num_tensor_lists == depth + 3,
"tensor_lists.size() != depth + 3, tensor_lists should have 3 more tensors (scale, "
"amax, scale_inv) for fp8");
} else {
NVTE_CHECK(num_tensor_lists == depth, "tensor_lists.size() != depth");
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth, USE_FP8> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
const OptionalCUDAGuard device_guard(device_id);
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
auto kernel = &multi_tensor_apply_kernel<block_size, TensorListMetadata<depth, USE_FP8>, T, ArgTypes...>;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int t = 0; t < num_tensors_per_list; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t]->numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t]->data.dptr;
if constexpr (USE_FP8) {
for (int i = 0; i < 3; i++)
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t].data_ptr();
tl.fp8_meta_addresses[i][loc_tensor_info] = tensor_lists[depth + i][t]->data.dptr;
}
loc_tensor_info++;
auto chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
auto chunks_this_tensor = (tensor_lists[0][t]->numel() + chunk_size - 1) / chunk_size;
for (auto chunk = 0; chunk < chunks_this_tensor; chunk++) {
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
......@@ -111,12 +134,12 @@ void multi_tensor_apply(int64_t chunk_size, const at::Tensor &noop_flag,
bool tensors_full =
(loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
bool last_chunk = (t == num_tensors_per_list - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);
chunk_size, reinterpret_cast<int *>(noop_flag.data.dptr), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
......
......@@ -4,19 +4,20 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include <cuda_fp8.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include <iostream>
#include <sstream>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_scale {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -66,7 +67,7 @@ struct ScaleFunctor {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
// store
load_store(out, r_out, i_start, 0);
......@@ -76,7 +77,7 @@ struct ScaleFunctor {
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_in[ii] = 0;
r_in[ii] = 0.f;
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) r_in[ii] = in[i];
}
......@@ -88,7 +89,7 @@ struct ScaleFunctor {
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
r_out[ii] = static_cast<float>(r_in[ii]) * scale;
finite = finite && isfinite(r_in[ii]);
finite = finite && isfinite(static_cast<float>(r_in[ii]));
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
......@@ -101,20 +102,29 @@ struct ScaleFunctor {
}
};
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
using namespace at;
// The output (downscaled) type is always float.
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
DISPATCH_FLOAT_HALF_AND_BFLOAT(
tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists, float scale,
const int device_id, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[0][0]->dtype(), p_in_type,
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
tensor_lists[1][0]->dtype(), g_in_type,
multi_tensor_apply<BLOCK_SIZE, 2>(chunk_size, noop_flag, tensor_lists,
ScaleFunctor<scalar_t_0, scalar_t_1>(), scale);))
AT_CUDA_CHECK(cudaGetLastError());
ScaleFunctor<p_in_type, g_in_type>(), device_id, stream, scale);))
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_scale
} // namespace transformer_engine
void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float scale, const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_scale_cuda);
using namespace transformer_engine;
// AT_CUDA_CHECK(cudaDeviceSynchronize());
multi_tensor_scale::multi_tensor_scale_cuda(
chunk_size, *reinterpret_cast<Tensor *>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id,
stream);
}
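// Call-shape sketch (inferred from the depth-2 dispatch above): tensor_lists
// carries two parallel lists, [0] the source tensors and [1] the destination
// tensors, each with num_tensors_per_list entries. Every destination element is
// the matching source element multiplied by `scale`; the functor also tracks
// whether all inputs are finite (see `finite` in ScaleFunctor).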
......@@ -4,14 +4,16 @@
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h>
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
namespace transformer_engine {
namespace multi_tensor_sgd {
#define BLOCK_SIZE 512
#define ILP 4
......@@ -54,9 +56,9 @@ struct SGDFunctor {
T_weight* mom_in = reinterpret_cast<T_weight*>(tl.addresses[2][tensor_loc]);
mom_in += chunk_idx * chunk_size;
at::Half* model_weights_out = nullptr;
fp16* model_weights_out = nullptr;
if (N == 4) {
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out = reinterpret_cast<fp16*>(tl.addresses[3][tensor_loc]);
model_weights_out += chunk_idx * chunk_size;
}
......@@ -112,7 +114,7 @@ struct SGDFunctor {
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if (N == 4) model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
if (N == 4) model_weights_out[i] = static_cast<fp16>(weight_in[i]);
// also write out the new momentum
if (momentum != 0.f) mom_in[i] = incoming_moms[ii];
......@@ -122,23 +124,23 @@ struct SGDFunctor {
}
};
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale) {
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if (num_tensors == 4) {
for (int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor*>> tensor_lists, float wd, float momentum,
float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale, const int device_id,
cudaStream_t stream) {
const size_t num_tensor_lists = tensor_lists.size();
const size_t num_tensors_per_list = tensor_lists[0].size();
auto grad_type = tensor_lists[0][0]->dtype();
auto weight_type = tensor_lists[1][0]->dtype();
if (num_tensor_lists == 4) {
for (int i = 0; i < num_tensors_per_list; i++)
NVTE_CHECK(tensor_lists[3][i]->dtype() == DType::kFloat16,
"Additional output tensors should always be fp16.");
}
TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(),
"expected noop flag to be on the same device as tensors");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
......@@ -150,53 +152,51 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if (grad_type == at::ScalarType::Half && weight_type == at::ScalarType::Half &&
num_tensors == 3) {
if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, at::Half, at::Half>(), wd, momentum, dampening, lr,
nesterov, first_run, wd_after_momentum, scale);
SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<BLOCK_SIZE, 3>(
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 3) {
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 3) {
multi_tensor_apply<BLOCK_SIZE, 3>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<3, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Half && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) {
else if (grad_type == DType::kFloat16 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, at::Half, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if (grad_type == at::ScalarType::Float && // NOLINT(*)
weight_type == at::ScalarType::Float && num_tensors == 4) {
else if (grad_type == DType::kFloat32 && // NOLINT(*)
weight_type == DType::kFloat32 && num_tensor_lists == 4) {
multi_tensor_apply<BLOCK_SIZE, 4>(chunk_size, noop_flag, tensor_lists,
SGDFunctor<4, float, float>(), wd, momentum, dampening, lr, nesterov,
first_run, wd_after_momentum, scale);
SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening,
lr, nesterov, first_run, wd_after_momentum, scale);
} else {
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
NVTE_ERROR("Unsupported combination of weight and gradient types.");
}
AT_CUDA_CHECK(cudaGetLastError());
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_sgd
} // namespace transformer_engine
void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor** tensor_lists,
const size_t num_tensor_lists, const size_t num_tensors_per_list,
float wd, float momentum, float dampening, float lr, int nesterov,
int first_run, int wd_after_momentum, float scale,
const int device_id, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_sgd_cuda);
using namespace transformer_engine;
multi_tensor_sgd::multi_tensor_sgd_cuda(
chunk_size, *reinterpret_cast<Tensor*>(noop_flag),
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum,
dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream);
}
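// Layout sketch (inferred from the dispatch above): tensor_lists holds
// [0] gradients, [1] weights, [2] momentum buffers, and optionally
// [3] fp16 model-weight copies when num_tensor_lists == 4; only the four
// dtype combinations enumerated in multi_tensor_sgd_cuda are supported.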
......@@ -6,6 +6,8 @@
#include <transformer_engine/permutation.h>
#include <cub/cub.cuh>
#include "../common.h"
#ifdef __HIP_PLATFORM_AMD__
......@@ -385,3 +387,11 @@ void nvte_unpermute(const NVTETensor input, NVTETensor output, NVTETensor row_id
reinterpret_cast<const float *>(prob_cu->data.dptr), num_rows, topK,
num_cols, stream););
}
void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes, int *keys_in,
int *keys_out, int *values_in, int *values_out,
size_t num_items) {
NVTE_API_CALL(nvte_device_radix_sort_pairs);
cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out, values_in,
values_out, num_items);
}
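// Usage sketch (hypothetical caller, error checking omitted): CUB's SortPairs
// follows a two-pass convention and this wrapper forwards it directly -- call
// once with a null temp_storage to query the required size, then allocate the
// scratch buffer and sort.
static void radix_sort_pairs_example(int *keys_in, int *keys_out, int *values_in,
                                     int *values_out, size_t num_items) {
  size_t temp_bytes = 0;
  nvte_device_radix_sort_pairs(nullptr, &temp_bytes, keys_in, keys_out, values_in,
                               values_out, num_items);  // size query only
  void *temp_storage = nullptr;
  cudaMalloc(&temp_storage, temp_bytes);                // device scratch space
  nvte_device_radix_sort_pairs(temp_storage, &temp_bytes, keys_in, keys_out,
                               values_in, values_out, num_items);  // actual sort
  cudaFree(temp_storage);
}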
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace fp8_block_scaling_recipe {
constexpr int kTileDim = 128;
constexpr int kThreadsPerBlock = 256;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_compute_partial_amax_kernel(const IType *input, float *amax_ptr,
const size_t amax_stride_h,
const size_t amax_stride_w, const size_t h,
const size_t w, const size_t start_offset,
const size_t len) {
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kLoopsPerCol = kTileDim / kNumWarps;
const int tile_col = blockIdx.x;
const int tile_row = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
__shared__ float smem[kNumWarps];
float amax = 0.0f;
for (int loop_col = 0; loop_col < kLoopsPerCol; ++loop_col) {
size_t r = tile_row * kTileDim + loop_col * kNumWarps + threadIdx.x / kThreadsPerWarp;
for (int loop_row = 0; loop_row < kLoopsPerRow; ++loop_row) {
size_t c = tile_col * kTileDim + loop_row * kThreadsPerWarp + (threadIdx.x % kThreadsPerWarp);
size_t idx = r * w + c;
if (r < h && c < w && idx >= start_offset && idx < end_offset) {
float other_amax = fabs(static_cast<float>(input_minus_offset[idx]));
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
if (threadIdx.x % kThreadsPerWarp == 0) {
smem[threadIdx.x / kThreadsPerWarp] = amax;
}
__syncthreads();
if (threadIdx.x == 0) {
for (int i = 0; i < kNumWarps; ++i) {
float other_amax = smem[i];
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax_ptr[tile_row * amax_stride_h + tile_col * amax_stride_w] = amax;
}
}
template <typename IType, typename OType, bool kWidthAligned>
__global__ void __launch_bounds__(kThreadsPerBlock)
fp8_block_scaling_partial_cast_kernel(const IType *input, OType *output, const float *scale_ptr,
const size_t scale_stride_h, const size_t scale_stride_w,
const size_t h, const size_t w, const size_t start_offset,
const size_t len) {
using transformer_engine::Vec;
static_assert(sizeof(OType) == 1);
constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
constexpr int kRowsPerWarp = kTileDim / kNumWarps;
__shared__ OType smem[kTileDim][kTileDim + kNumOutputElemsPerBank];
const int tile_w = blockIdx.x;
const int tile_h = blockIdx.y;
const size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_minus_offset = output - start_offset;
const float scale = scale_ptr[tile_h * scale_stride_h + tile_w * scale_stride_w];
// Load input data into shared memory
bool skip_store = true;
for (int i = 0; i < kRowsPerWarp; ++i) {
for (int j = 0; j < kLoopsPerRow; ++j) {
const int h_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int w_in_smem = threadIdx.x % kThreadsPerWarp + kThreadsPerWarp * j;
const int h_in_input = tile_h * kTileDim + h_in_smem;
const int w_in_input = tile_w * kTileDim + w_in_smem;
const size_t idx_in_input = static_cast<size_t>(h_in_input) * w + w_in_input;
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
skip_store = false;
}
}
}
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
skip_store = skip_store && other_skip_store;
}
skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
if (skip_store) {
return;
}
  // Store the cast data into the output.
// Note that this store operation might write "out-of-bounds", but it is intentional:
// 1. The "out-of-bounds" here only crosses the boundary of the "local shard" (i.e., the region
// from start_offset to end_offset), not the boundary of the entire output memory. Therefore,
// this out-of-bounds write will not cause illegal memory access.
// 2. We assume that the subsequent all-gather operation happens in-place, so any parts that
// should not be updated here will be overwritten by the all-gather.
// This tricky approach allows us to avoid checking whether each output index falls within
// [start, end), resulting in a significant performance improvement.
Vec<OType, kNumOutputElemsPerBank> vec_output;
for (int i = 0; i < kRowsPerWarp; ++i) {
const int row_in_smem = threadIdx.x / kThreadsPerWarp * kRowsPerWarp + i;
const int col_in_smem = threadIdx.x % kThreadsPerWarp * kNumOutputElemsPerBank;
for (int j = 0; j < kNumOutputElemsPerBank; ++j) {
vec_output.data.elt[j] = smem[row_in_smem][col_in_smem + j];
}
const int row_in_output = tile_h * kTileDim + row_in_smem;
const int col_in_output = tile_w * kTileDim + col_in_smem;
const size_t idx_in_output = static_cast<size_t>(row_in_output) * w + col_in_output;
if (row_in_output < h) {
if constexpr (kWidthAligned) {
vec_output.store_to(output_minus_offset + idx_in_output);
} else {
int num = min(static_cast<size_t>(kNumOutputElemsPerBank),
static_cast<size_t>(col_in_output < w ? w - col_in_output : 0));
vec_output.store_to_elts(output_minus_offset, idx_in_output, num);
}
}
}
}
void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_t h, size_t w,
size_t amax_stride_h, size_t amax_stride_w,
size_t start_offset, size_t block_len,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
size_t len = inp.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
fp8_block_scaling_compute_partial_amax_kernel<inp_dtype>
<<<grid, kThreadsPerBlock, 0, stream>>>(reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<float *>(amax.data.dptr),
amax_stride_h, amax_stride_w, h, w, start_offset,
len);)
}
void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h,
size_t w, size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len, const DType out_dtype,
cudaStream_t stream) {
NVTE_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
size_t len = inp.numel();
assert(h > 0 && w > 0);
assert(start_offset < h * w);
assert(start_offset + len <= h * w);
size_t blocks_x = (w + kTileDim - 1) / kTileDim;
size_t blocks_y = (h + kTileDim - 1) / kTileDim;
assert(blocks_x <= std::numeric_limits<unsigned int>::max());
assert(blocks_y <= std::numeric_limits<unsigned int>::max());
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
inp.dtype(), inp_dtype,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
out_dtype, fp8_type,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
w % kTileDim == 0, kWidthAligned,
fp8_block_scaling_partial_cast_kernel<inp_dtype, fp8_type, kWidthAligned>
<<<grid, kThreadsPerBlock, 0, stream>>>(
reinterpret_cast<const inp_dtype *>(inp.data.dptr),
reinterpret_cast<fp8_type *>(out.data.dptr),
reinterpret_cast<const float *>(scale.data.dptr), scale_stride_h, scale_stride_w,
h, w, start_offset, len);)))
}
} // namespace fp8_block_scaling_recipe
} // namespace transformer_engine
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_block_scaling_compute_partial_amax);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_compute_partial_amax(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(amax), h, w,
amax_stride_h, amax_stride_w, start_offset, block_len, stream);
}
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_block_scaling_partial_cast);
using namespace transformer_engine;
fp8_block_scaling_recipe::fp8_block_scaling_partial_cast(
*reinterpret_cast<const Tensor *>(inp), *reinterpret_cast<Tensor *>(out),
*reinterpret_cast<const Tensor *>(scale), h, w, scale_stride_h, scale_stride_w, start_offset,
block_len, static_cast<DType>(out_dtype), stream);
}
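// Shape sketch (hypothetical caller): the input is partitioned into 128x128
// (kTileDim) tiles and the amax/scale buffers are addressed as
//   buf[tile_row * stride_h + tile_col * stride_w],
// so for an (h x w) input a dense row-major layout would use
//   tiles_h = (h + 127) / 128, tiles_w = (w + 127) / 128,
//   stride_h = tiles_w, stride_w = 1,
// giving tiles_h * tiles_w float entries in total.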
......@@ -10,6 +10,7 @@
#include <iostream>
#include "common.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
......@@ -48,11 +49,11 @@ std::string to_string(const DType type) {
std::string to_string(const NVTEScalingMode &mode) {
switch (mode) {
case NVTE_DELAYED_TENSOR_SCALING:
return "Delayed Tensor Scaling";
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "MXFP8 1D Scaling";
return "NVTE_MXFP8_1D_SCALING";
case NVTE_INVALID_SCALING:
return "Invalid Scaling";
return "NVTE_INVALID_SCALING";
}
return "Invalid Scaling";
}
......@@ -214,15 +215,13 @@ NVTEDType nvte_tensor_type(const NVTETensor tensor) {
NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTEShape ret;
if (ndim == 0) {
ret.data = nullptr;
ret.ndim = 0;
return ret;
}
NVTE_CHECK(ndim <= sizeof(ret.owned_data) / sizeof(ret.owned_data[0]),
NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
"Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.owned_data) / sizeof(ret.owned_data[0]), ")");
std::copy(data, data + ndim, ret.owned_data);
ret.data = ret.owned_data;
", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
std::copy(data, data + ndim, ret.data);
ret.ndim = ndim;
return ret;
}
......@@ -350,7 +349,7 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, {nullptr, 0}};
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
}
const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
switch (param_name) {
......@@ -483,3 +482,13 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
delete reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
}
}
int nvte_is_non_tn_fp8_gemm_supported() {
int deviceComputeCapability =
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
  // Note: this is a temporary restriction and should be lifted in the future.
  // (Remove this note once that is done.)
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
}
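// Usage sketch (hypothetical call site): callers can branch on this to decide
// whether FP8 GEMM operands must be kept in a transposed (TN) layout on the
// current GPU, e.g.
//   if (!nvte_is_non_tn_fp8_gemm_supported()) { /* maintain transposed FP8 copies */ }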
......@@ -134,9 +134,15 @@ bool supports_multicast(int device_id) {
auto init = [&]() {
CUdevice cudev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &cudev, device_id);
int result;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
    // Multicast support requires both a CUDA 12.1+ UMD and KMD
int result = 0;
// Check if KMD >= 12.1
int driver_version;
NVTE_CHECK_CUDA(cudaDriverGetVersion(&driver_version));
if (driver_version >= 12010) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &result,
CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, cudev);
}
cache[device_id] = static_cast<bool>(result);
};
std::call_once(flags[device_id], init);
......
......@@ -23,10 +23,18 @@
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#include <iostream>
#include <stdexcept>
#include "../util/string.h"
#define NVTE_WARN(...) \
do { \
std::cerr << ::transformer_engine::concat_strings( \
__FILE__ ":", __LINE__, " in function ", __func__, ": ", \
::transformer_engine::concat_strings(__VA_ARGS__), "\n"); \
} while (false)
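// Usage sketch (hypothetical call site): arguments are joined via
// transformer_engine::concat_strings, so anything printable with operator<< works,
// and the message is prefixed with file, line, and function, e.g.
//   NVTE_WARN("Unexpected stride ", stride, "; falling back to the unfused path.");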
#define NVTE_ERROR(...) \
do { \
throw ::std::runtime_error(::transformer_engine::concat_strings( \
......