"vscode:/vscode.git/clone" did not exist on "2bb532fbf33766c089f4193dd6ef745cae5301d3"
Commit 2b05e121 authored by yuguo's avatar yuguo
Browse files

Merge commit 'a69692ac' of...

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -190,14 +190,14 @@ Step 2: Cast and store to output_c
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3 (if columnwise transpose is True, GEMM_READY): Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
  * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
  * Every 8 consecutive threads do a reduction and calculate the amax of each column
  * 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | |
| T1 | T9 | T17 | T25 | | | | |
@@ -209,6 +209,29 @@ Step 3: Transpose, cast and store to output_t
| T7 | T15 | T23 | T31 | | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
Step 3 (if columnwise transpose is False, COMPACT format): Skip transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 1 time
* What each thread does in each loop:
  * 16 elements (in a row) are read from shared memory, for a total of 4 rows;
    reading 16 elements in a row takes 8 smem reads, so the thread tile shape is 4 rows x 16 columns
  * All 32 threads of a warp reduce together to calculate the amax of each column,
    so each thread performs 16 warp-shuffle reductions, one per column of its tile
  * 16 elements are quantized and written to output_t at a time, for a total of 4 times
+------16 elements-------+------16 elements-------+-----80 elements-----+------16 elements------+
| T0 | | | |
| T1 | | | |
| T2 | | | |
| T3 | | | |
| T4 | | | |
| T5 | | | |
| T6 | | | |
| T7 | | | |
| ... | | | |
| T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
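For orientation, the two columnwise output formats differ mainly in how output_t and its scale inverses are laid out. Below is a minimal host-side sketch of the scale indexing (the helper name is ours, not part of this change; padding is ignored) that mirrors the stride choices made later in quantize_transpose_vector_blockwise:

// Sketch only: mirrors how scale_t_stride_x / scale_t_stride_y are chosen further down.
// COLUMNWISE_GEMM_READY: stride_x = scale_t_k, stride_y = 1  -> scales stored as [row_length, num_row_tiles]
// COLUMNWISE_COMPACT:    stride_x = 1,         stride_y = scale_t_k -> scales stored as [num_row_tiles, row_length]
// (scale_t_k is scale_inv_t.shape[1]; tile_row indexes 128-row blocks, col indexes columns.)
inline size_t scale_t_flat_index(bool columnwise_compact, size_t scale_t_k, size_t tile_row, size_t col) {
  const size_t stride_x = columnwise_compact ? 1 : scale_t_k;
  const size_t stride_y = columnwise_compact ? scale_t_k : 1;
  return tile_row * stride_y + col * stride_x;
}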
@@ -231,6 +254,7 @@ constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
@@ -240,9 +264,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
bool return_columnwise_gemm_ready =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
bool return_columnwise_compact =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
@@ -439,8 +465,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
@@ -554,6 +580,103 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
}
// Step 3 (return_columnwise_compact): cast in 128x1 blocks and store to output_t, skipping the transpose
if (return_columnwise_compact) {
// Thread tile is 4 rows x 16 columns; the 16 columns take 8 shared-memory vector reads
constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
constexpr int kThreadTileCol = kNVecOut;
using RegVec = Vec<IType, kThreadTileCol>;
using RegScaleVec = Vec<CType, kThreadTileCol>;
constexpr int num_smem_reads = kNVecOut / kNVecSMem;
// c_stride will not be used here because we only have one iteration
// constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
constexpr int num_iterations =
kTileDim / (kNumWarps * kThreadTileCol); // should be only one iteration
static_assert(num_iterations == 1,
"num_iterations should be 1 for columnwise non-transpose case");
const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
const int warp_idx = threadIdx.x / kThreadsPerWarp;
const int r_s = thr_idx_in_warp * kThreadTileRow; // Row in shared memory
int c_s = warp_idx * num_smem_reads; // Column in shared memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
const size_t num_ele = c_g < row_length
? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
: 0; // For not aligned case
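// num_ele: number of valid columns in this thread's 16-element strip when the tile
// extends past the last column of the tensor (non-aligned case).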
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
RegVec reg_vec[kThreadTileRow];
RegScaleVec thr_scale;
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
int r = r_s + i;
#pragma unroll
for (int j = 0; j < num_smem_reads; ++j) {
int c = c_s + j;
SMemVec smem_vec = smem[r * kSMemCol + c];
// Copy the elements of smem_vec into the register vector
#pragma unroll
for (int k = 0; k < kNVecSMem; ++k) {
reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
}
}
}
#pragma unroll
for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
}
// Step 3.3: Reduce amax
const bool is_src_lane = thr_idx_in_warp == 0;
amax = warp_reduce_max<kThreadsPerWarp>(amax);
constexpr int lane_zero = 0;
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
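// After the shuffle broadcast every lane holds the same column amax, so all lanes
// compute an identical scale below and only lane 0 (is_src_lane) writes scale_inv_t.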
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
thr_scale.data.elt[reg_idx] = scale;
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (c_g + reg_idx < row_length);
}
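// Each 128x1 block (one column of the 128-row tile) produces a single scale;
// only lane 0 of the warp stores its inverse.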
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y);
size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
OType* output_g =
&output_t[(r_g + row_idx) * row_length + c_g]; // Output address in global memory
OVec output_vec;
#pragma unroll
for (int i = 0; i < kThreadTileCol; ++i) {
output_vec.data.elt[i] = static_cast<OType>(
static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g + row_idx < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
}
// Step 3.8: Update output address and shared-memory column index
// (no-op here: num_iterations == 1 for the compact path)
}
}
}
}  // namespace
@@ -569,11 +692,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
const bool pow2_scale, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
@@ -594,32 +712,43 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
size_t scale_t_stride_y = 0;
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY ||
rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT,
"Unexpected rowwise enum value");
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
bool rowwise_compact = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT;
scale_stride_x = rowwise_compact ? 1 : scale_k;
scale_stride_y = rowwise_compact ? scale_k : 1;
}
if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
if (columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
} else {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT,
"Unexpected columnwise option enum value");
NVTE_CHECK(output_t.shape[0] == input.shape[0], "Wrong dimension 0 of output_t.");
NVTE_CHECK(
input.shape == output_t.shape,
"Input and output_t must have the same shape for columnwise non-transpose case.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
size_t scale_t_k = scale_inv_t.shape[1];
scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
}
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
......
@@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
}
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
convertNVTETensor(output), stream);
}
@@ -386,17 +386,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()),
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
}
}
@@ -513,7 +514,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine;
fp8_transpose_dbias(*convertNVTETensorCheck(input), convertNVTETensor(transposed_output),
convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
}
@@ -10,13 +10,16 @@
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
@@ -156,6 +159,45 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_configs,
const size_t num_tensors, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
const size_t num_streams = nvte_get_num_compute_streams();
int num_stream_used = std::min(num_streams, num_tensors);
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
}
for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(
cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
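The fork-join pattern above (record an event on the incoming stream, make each worker stream wait on it, launch work round-robin, then make the incoming stream wait on per-worker events) is a standard CUDA idiom. A minimal, self-contained sketch, independent of the TE helpers and with made-up kernel, stream, and buffer names (cleanup omitted):

// Sketch only: generic fork-join across two worker streams (not TE code).
__global__ void work(float *p) { p[threadIdx.x] += 1.f; }

void fork_join_example(cudaStream_t main_stream, float *buf0, float *buf1) {
  cudaStream_t workers[2];
  cudaEvent_t begin, done[2];
  cudaEventCreate(&begin);
  for (int i = 0; i < 2; ++i) {
    cudaStreamCreateWithFlags(&workers[i], cudaStreamNonBlocking);
    cudaEventCreate(&done[i]);
  }
  cudaEventRecord(begin, main_stream);  // fork: order workers after main_stream
  for (int i = 0; i < 2; ++i) cudaStreamWaitEvent(workers[i], begin);
  work<<<1, 32, 0, workers[0]>>>(buf0);
  work<<<1, 32, 0, workers[1]>>>(buf1);
  for (int i = 0; i < 2; ++i) {
    cudaEventRecord(done[i], workers[i]);  // join: main_stream waits on each worker
    cudaStreamWaitEvent(main_stream, done[i]);
  }
}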
@@ -763,19 +763,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
@@ -862,31 +863,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0,
typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols,
typeToNumBits(gated_input.dtype()));
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0,
typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
cols, typeToNumBits(output->dtype()));
}
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
@@ -1071,10 +1074,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input);
Tensor *output_tensor = convertNVTETensorCheck(output);
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
......
@@ -904,15 +904,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
}
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
@@ -1004,24 +1004,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
typeToNumBits(input.dtype()));
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
typeToNumBits(output->dtype()));
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
typeToNumBits(output->dtype()));
}
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
@@ -1133,7 +1133,7 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16;
const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
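Working in bits rather than bytes keeps the same arithmetic valid for sub-byte element types. A quick worked example (element widths chosen by us for illustration): with TMA_bytes = 16, a 16-bit type needs the last dimension to be a multiple of 8 elements, an 8-bit type a multiple of 16, and a 4-bit type a multiple of 32, which is the case a byte-based typeToSize could not express.

// Illustrative only: the alignment arithmetic above for a few element widths.
static_assert((16 * 8) / 16 == 8, "16-bit elements: 8-element alignment");
static_assert((16 * 8) / 8 == 16, "8-bit elements: 16-element alignment");
static_assert((16 * 8) / 4 == 32, "4-bit elements: 32-element alignment");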
@@ -1254,23 +1254,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) {
// backward - input is incoming gradient
input_tensor = convertNVTETensorCheck(grad);
activation_input_tensor = convertNVTETensor(input);
} else {
// forward - input is activation input
input_tensor = convertNVTETensorCheck(input);
activation_input_tensor = nullptr;
}
auto output_tensor = convertNVTETensorCheck(output);
auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = convertNVTETensor(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
@@ -1315,12 +1315,25 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
......
@@ -13,16 +13,42 @@ namespace transformer_engine {
namespace cuda_driver {
#ifndef USE_ROCM
typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int,
unsigned long long, // NOLINT(*)
cudaDriverEntryPointQueryResult *);
typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long, // NOLINT(*)
cudaDriverEntryPointQueryResult *);
#endif
void *get_symbol(const char *symbol, int cuda_version) {
#ifndef USE_ROCM
constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
// We link to the libcudart.so already, so can search for it in the current context
static GetEntryPoint driver_entrypoint_fun =
reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));
cudaDriverEntryPointQueryResult driver_result;
#endif
void *entry_point = nullptr;
#ifdef USE_ROCM
hipDriverProcAddressQueryResult driver_result;
NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point, HIP_VERSION_MAJOR*100+HIP_VERSION_MINOR, 0, &driver_result));
NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
"Could not find CUDA driver entry point for ", symbol);
#else
if (driver_entrypoint_versioned_fun != nullptr) {
// Found versioned entrypoint function
NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
cudaEnableDefault, &driver_result));
} else {
NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
// Versioned entrypoint function not found
NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
}
NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
"Could not find CUDA driver entry point for ", symbol);
#endif
......
@@ -19,7 +19,7 @@ namespace transformer_engine {
namespace cuda_driver {
/*! \brief Get pointer corresponding to symbol in CUDA driver library */
void *get_symbol(const char *symbol, int cuda_version = 12010);
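A typical use, shown here only as a sketch (the chosen driver symbol and surrounding code are our own example, not part of this change), is to resolve the symbol once and cast it to the matching driver-API function type:

// Example only: resolve cuDeviceGetCount through get_symbol and call it.
int device_count_via_driver() {
  using cuDeviceGetCount_t = CUresult (*)(int *);
  auto fn = reinterpret_cast<cuDeviceGetCount_t>(
      transformer_engine::cuda_driver::get_symbol("cuDeviceGetCount", 12010));
  int count = 0;
  if (fn(&count) != CUDA_SUCCESS) return -1;  // driver call failed
  return count;
}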
/*! \brief Call function in CUDA driver library
 *
......
@@ -326,9 +326,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "multi_stream.h"
#include <transformer_engine/multi_stream.h>
#include <mutex>
#include <vector>
#include "cuda_runtime.h"
#include "logging.h"
namespace transformer_engine::detail {
cudaStream_t get_compute_stream(int idx) {
const size_t num_streams = nvte_get_num_compute_streams();
NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx,
", but there are ", num_streams, " streams)");
static std::vector<cudaStream_t> streams(num_streams);
static std::once_flag stream_init_flag;
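// Streams are created lazily on the first call and reused for the lifetime of the process;
// priority -1 requests a higher-than-default priority (lower values mean higher priority in CUDA).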
auto init = [&]() {
for (size_t i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1));
}
};
std::call_once(stream_init_flag, init);
return streams[idx];
}
cudaEvent_t get_compute_stream_event(int idx) {
const size_t num_streams = nvte_get_num_compute_streams();
NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx,
", but there are ", num_streams, " streams)");
static std::vector<cudaEvent_t> events(num_streams);
static std::once_flag event_init_flag;
auto init = [&]() {
for (size_t i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaEventCreate(&events[i]));
}
};
std::call_once(event_init_flag, init);
return events[idx];
}
int get_num_compute_streams() {
static constexpr int num_compute_streams = 4;
return num_compute_streams;
}
} // namespace transformer_engine::detail
int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#include <cuda_runtime.h>
namespace transformer_engine::detail {
int get_num_compute_streams();
cudaStream_t get_compute_stream(int idx);
cudaEvent_t get_compute_stream_event(int idx);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
@@ -155,8 +155,8 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
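// For example (illustrative numbers only): with 32 threads per warp, an 8-byte desired
// load/store and a 16-bit dtype, tile_dim = 32 * 8 * 8 / 16 = 128 elements per tile side.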
// Add tensors to kernel argument struct
MultiPaddingArgs kernel_args;
@@ -211,8 +211,8 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
std::vector<Tensor*> input_list_, output_list_;
std::vector<int> padded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
padded_num_rows_list_.push_back(padded_num_rows_list[i]);
}
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
......
@@ -80,6 +80,10 @@
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \
m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
.value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \
......
@@ -21,6 +21,10 @@ using namespace __hip_internal;
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif
#ifdef __HIP_PLATFORM_AMD__
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
......
@@ -12,7 +12,7 @@ from nvdlfw_inspect.registry import Registry
import torch
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import get_all_tensor_types
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
@@ -424,7 +424,7 @@ class TransformerEngineAPI(BaseNamespaceAPI):
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
assert type(ret) in get_all_tensor_types()
if (
type(ret) == torch.Tensor  # pylint: disable=unidiomatic-typecheck
and "dtype" in kwargs
@@ -438,4 +438,4 @@ class TransformerEngineAPI(BaseNamespaceAPI):
def end_debug(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
TEDebugState._reset()
@@ -49,7 +49,7 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
fp8_dtype = tex.DType.kFloat8E5M2
amax = tensor.abs().max().float()
one = torch.ones(1, device=tensor.device)
scale = _default_sf_compute(amax, one, fp8_max, 0)
quantizer = Float8Quantizer(scale, amax, fp8_dtype)
else:
......
@@ -120,7 +120,6 @@ class LogFp8TensorStats(BaseLogTensorStats):
if not rowwise:
return  # tensor was already seen rowwise in the other gemm
options = (
config.get("start_step", None),
config.get("end_step", None),
......
@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.debug.features.api import TEConfigAPIMapper
@@ -39,7 +40,7 @@ def per_tensor_cast(
}, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE."
tensor = tensor.contiguous()
quantizer = Float8CurrentScalingQuantizer(fp8_dtype, device=tensor.device)
if out is not None:
quantizer.update_quantized(tensor, out)
@@ -81,7 +82,6 @@ class PerTensorScaling(TEConfigAPIMapper):
transformer_engine:
PerTensorScaling:
enabled: True
gemms: [dgrad]
tensors: [weight, activation]
"""
@@ -118,7 +118,7 @@ class PerTensorScaling(TEConfigAPIMapper):
if key not in ["gemm", "tensor"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
assert isinstance(default_quantizer, Float8Quantizer), (
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: "
"Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast."
f" {layer_name}"
......
@@ -96,7 +96,10 @@ STATS = {
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
"numel": (
lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda buffers: sum(_get(buffers, "numel")),
),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))), "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
"l2_norm_square": ( "l2_norm_square": (
lambda x: torch.sum(x**2), lambda x: torch.sum(x**2),
...@@ -137,7 +140,7 @@ STATS = { ...@@ -137,7 +140,7 @@ STATS = {
- min(_get(buffers, "dynamic_range_bottom")), - min(_get(buffers, "dynamic_range_bottom")),
), ),
"underflows%": ( "underflows%": (
lambda x: (x == 0).sum() / x.numel() * 100, lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100,
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")), lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
), ),
} }
@@ -14,10 +14,11 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
QuantizedTensorBase,
prepare_for_saving,
restore_from_saved,
)
@@ -299,8 +300,9 @@ class DebugQuantizer(Quantizer):
iteration=self.iteration,
dtype=dtype,
)
if dtype is not None:
if columnwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
if self.rowwise_tensor_plan == API_CALL_MODIFY:
rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
@@ -311,8 +313,9 @@ class DebugQuantizer(Quantizer):
iteration=self.iteration,
dtype=dtype,
)
if dtype is not None:
if rowwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
# 3. If some tensors still are not defined we use high precision tensor.
if self.rowwise_tensor_plan == HIGH_PRECISION:
@@ -332,6 +335,7 @@ class DebugQuantizer(Quantizer):
quantizer=self,
layer_name=self.layer_name,
tensor_name=self.tensor_name,
original_tensor=tensor,
)
def process_gemm_output(self, tensor: torch.Tensor):
@@ -455,8 +459,12 @@ class DebugQuantizer(Quantizer):
return True
return False
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Probably not needed for debug quantizer"""
return None
class DebugQuantizedTensor(QuantizedTensorBase):
"""
Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method
@@ -470,6 +478,7 @@ class DebugQuantizedTensor:
quantizer,
layer_name=None,
tensor_name=None,
original_tensor=None,
):
self.rowwise_gemm_tensor = rowwise_gemm_tensor
@@ -477,6 +486,7 @@ class DebugQuantizedTensor:
self.quantizer = quantizer
self._layer_name = layer_name
self._tensor_name = tensor_name
self._original_tensor = original_tensor
def prepare_for_saving(self):
"""Prepare for saving method override"""
@@ -524,5 +534,5 @@ class DebugQuantizedTensor:
"""Size of the tensor."""
return self.rowwise_gemm_tensor.size()
def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
"""Update usage of the tensor."""