Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
......@@ -190,14 +190,14 @@ Step 2: Cast and store to output_c
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
Step 3 (if columnwise transpose is True, GEMM_READY): Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do a reduction together to calculate the amax of each column
* 16 elements are quantized and write to output_c at a time, for a total of 2 times
* 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
......@@ -209,6 +209,29 @@ Step 3: Transpose, cast and store to output_t
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
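Sanity check of the numbers above (assuming 8 warps x 32 threads = 256 threads per block, with the 2-element reads and 16-element stores listed above):
* each thread reads 2 x 16 = 32 elements and writes 16 x 2 = 32 elements per loop
* 256 threads x 32 elements x 2 loops = 16384 = 128 x 128, so the two loops cover the whole tile
* 128 / 16 = 8 threads jointly produce each 128-element row of output_t, which is why every 8 consecutive threads reduce their amax together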
Step 3 (if columnwise transpose is False, COMPACT format): Skip transpose; cast and store to output_t
* shared memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 1 time
* What each thread does in each loop:
* 16 elements (in a row) are read from the shared memory, for a total of 4 rows
* reading 16 elements in a row takes 8 smem reads, so the thread tile shape is 4x16 (4 rows x 16 columns)
* all 32 threads of a warp do the reduction together to calculate the amax of each column,
* so each thread performs the warp-wide reduction 16 times (once per column)
* 16 elements are quantized and written to output_t at a time, for a total of 4 times
+------16 elements-------+------16 elements-------+-----80 elements-----+------16 elements------+
| T0 | | | |
| T1 | | | |
| T2 | | | |
| T3 | | | |
| T4 | | | |
| T5 | | | |
| T6 | | | |
| T7 | | | |
| ... | | | |
| T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+
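Sanity check of the numbers above (assuming 256 threads per block and the 4x16 thread tile):
* each thread covers 4 rows x 16 columns = 64 elements; 256 threads x 64 = 16384 = 128 x 128, so one loop iteration covers the whole tile
* each warp owns 16 columns (8 warps x 16 = 128 columns), and thread t of a warp owns rows 4*t .. 4*t+3 (32 threads x 4 = 128 rows)
* a 128x1 scaling block therefore spans all 32 threads of a warp, so the amax reduction is warp-wide and is repeated 16 times per thread (once per column)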
*/
// clang-format on
......@@ -231,6 +254,7 @@ constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
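// For example, a 256-thread block gives 256 / 32 = 8 warps, matching the "8 warps" in the
// sketches above (the actual kThreadsPerBlock value is defined elsewhere in this file).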
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
......@@ -240,9 +264,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
bool return_columnwise_gemm_ready =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
bool return_columnwise_compact =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
......@@ -439,8 +465,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
// Step 3: Transpose, cast and store to output_t
if (return_columnwise_transpose) {
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
......@@ -554,6 +580,103 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
}
// Step 3 (return_columnwise_compact): cast with 128x1 scaling blocks and store to output_t, skipping the transpose
if (return_columnwise_compact) {
// thread tile is 4x16 (4 rows x 16 columns); the 16 columns take 8 smem reads of kNVecSMem elements each
constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
constexpr int kThreadTileCol = kNVecOut;
using RegVec = Vec<IType, kThreadTileCol>;
using RegScaleVec = Vec<CType, kThreadTileCol>;
constexpr int num_smem_reads = kNVecOut / kNVecSMem;
// c_stride will not be used here because we only have one iteration
// constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
constexpr int num_iterations =
kTileDim / (kNumWarps * kThreadTileCol); // should be only one iteration
static_assert(num_iterations == 1,
"num_iterations should be 1 for columnwise non-transpose case");
const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
const int warp_idx = threadIdx.x / kThreadsPerWarp;
const int r_s = thr_idx_in_warp * kThreadTileRow; // Row in shared memory
int c_s = warp_idx * num_smem_reads; // Column in shared memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
const size_t num_ele = c_g < row_length
? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
: 0; // For not aligned case
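// num_ele is how many of this thread's 16 columns actually fall inside the tensor when the
// tile overhangs the right edge; 0 means the whole thread tile is out of range.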
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
RegVec reg_vec[kThreadTileRow];
RegScaleVec thr_scale;
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
int r = r_s + i;
#pragma unroll
for (int j = 0; j < num_smem_reads; ++j) {
int c = c_s + j;
SMemVec smem_vec = smem[r * kSMemCol + c];
// copy the elements of smem_vec into reg_vec
#pragma unroll
for (int k = 0; k < kNVecSMem; ++k) {
reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
}
}
}
#pragma unroll
for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
}
// Step 3.3: Reduce amax
const bool is_src_lane = thr_idx_in_warp == 0;
amax = warp_reduce_max<kThreadsPerWarp>(amax);
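// The broadcast below makes every lane use the same amax for its scale, regardless of which
// lane(s) warp_reduce_max leaves the final value in.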
constexpr int lane_zero = 0;
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
thr_scale.data.elt[reg_idx] = scale;
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (c_g + reg_idx < row_length);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y);
size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
OType* output_g =
&output_t[(r_g + row_idx) * row_length + c_g]; // Output address in global memory
OVec output_vec;
#pragma unroll
for (int i = 0; i < kThreadTileCol; ++i) {
output_vec.data.elt[i] = static_cast<OType>(
static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g + row_idx < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
}
// Step 3.8: Update the output address and the shared-memory column index
// Not needed here: num_iterations is 1 (see the static_assert above)
}
}
}
} // namespace
......@@ -569,11 +692,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
const bool pow2_scale, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
// assert that rowwise_option and columnwise_option are not both NONE
NVTE_CHECK(rowwise_option != FP8BlockwiseRowwiseOption::NONE ||
columnwise_option != FP8BlockwiseColumnwiseOption::NONE,
"rowwise_option and columnwise_option cannot both be NONE");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
......@@ -594,32 +712,43 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
size_t scale_t_stride_y = 0;
if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE,
NVTE_CHECK(rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY ||
rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT,
"Unexpected rowwise enum value");
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
scale_stride_x = scale_k;
scale_stride_y = 1;
bool rowwise_compact = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE_COMPACT;
scale_stride_x = rowwise_compact ? 1 : scale_k;
scale_stride_y = rowwise_compact ? scale_k : 1;
}
if (columnwise_option != FP8BlockwiseColumnwiseOption::NONE) {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE,
"Unexpected columnwise enum value");
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
if (columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
} else {
NVTE_CHECK(columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT,
"Unexpected columnwise option enum value");
NVTE_CHECK(output_t.shape[0] == input.shape[0], "Wrong dimension 0 of output_t.");
NVTE_CHECK(
input.shape == output_t.shape,
"Input and output_t must have the same shape for columnwise non-transpose case.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
scale_t_stride_x = scale_inv_t.shape[1];
scale_t_stride_y = 1;
bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
size_t scale_t_k = scale_inv_t.shape[1];
scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
}
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
......
......@@ -288,14 +288,13 @@ void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stre
NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine;
auto noop = Tensor();
transpose(*reinterpret_cast<const Tensor *>(input), noop, reinterpret_cast<Tensor *>(output),
stream);
transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
}
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine;
transpose(*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(noop),
reinterpret_cast<Tensor *>(output), stream);
transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
convertNVTETensor(output), stream);
}
......@@ -386,17 +386,18 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype = DType::kFloat32;
} else {
// Check that workspace matches expected size
const size_t workspace_size =
const size_t workspace_size = get_buffer_size_bytes(
std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1,
std::multiplies<size_t>()) *
typeToSize(workspace->data.dtype);
const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32);
std::multiplies<size_t>()),
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
", dtype=", typeToSize(workspace->data.dtype), ")");
", dtype=", typeToNumBits(workspace->data.dtype), " bits)");
}
}
......@@ -513,7 +514,6 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_fp8_transpose_dbias);
using namespace transformer_engine;
fp8_transpose_dbias(
*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(transposed_output),
reinterpret_cast<Tensor *>(dbias), reinterpret_cast<Tensor *>(workspace), stream);
fp8_transpose_dbias(*convertNVTETensorCheck(input), convertNVTETensor(transposed_output),
convertNVTETensor(dbias), convertNVTETensor(workspace), stream);
}
......@@ -10,13 +10,16 @@
#endif
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/multi_stream.h>
#include <cfloat>
#include <limits>
#include <mutex>
#include <string>
#include "../common.h"
#include "../transpose/cast_transpose.h"
#include "../util/multi_stream.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "cast_kernels.cuh"
......@@ -156,6 +159,45 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
reinterpret_cast<Tensor *>(output), stream);
detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream);
}
void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
const NVTEQuantizationConfig quant_configs,
const size_t num_tensors, cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_quantize);
using namespace transformer_engine;
constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = false;
constexpr bool IS_ACT = false;
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;
constexpr const NVTETensor grad = nullptr;
const size_t num_streams = nvte_get_num_compute_streams();
int num_stream_used = std::min(num_streams, num_tensors);
// wait for current stream to finish
NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream));
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0)));
}
for (int i = 0; i < num_tensors; i++) {
detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
inputs[i], grad, outputs[i], dbias, workspace, nullptr,
detail::get_compute_stream(i % num_streams));
}
// record events on compute streams
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(
cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s)));
}
// wait for all compute streams to finish
for (int s = 0; s < num_stream_used; s++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
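// Illustrative caller sketch (example only, not part of the diff above; assumes `ins` and
// `outs` are pre-built NVTETensor handles). The quant_configs argument is not forwarded by
// the body above, so a null config is acceptable.
static void example_multi_quantize(const NVTETensor *ins, NVTETensor *outs, size_t n,
                                   cudaStream_t stream) {
  nvte_multi_tensor_quantize(ins, outs, /*quant_configs=*/nullptr, n, stream);
  // On return, `stream` has been made to wait on every internal compute stream, so work
  // enqueued on `stream` afterwards can safely consume the quantized outputs.
}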
......@@ -763,19 +763,20 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
cols, 0, sizeof(IType));
cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType));
SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType));
SHMEM_DIM_X, tensor_stride_elems, cols,
typeToNumBits(output->dtype()));
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
const size_t buff_size_aligned_in =
......@@ -862,31 +863,33 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
if constexpr (IS_DGATED) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType));
SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype()));
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType));
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0,
typeToNumBits(gated_input.dtype()));
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType));
SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols,
typeToNumBits(gated_input.dtype()));
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0,
sizeof(OType));
typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols,
sizeof(OType));
typeToNumBits(output->dtype()));
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
0, sizeof(OType));
0, typeToNumBits(output->dtype()));
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data,
rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems,
cols, sizeof(OType));
cols, typeToNumBits(output->dtype()));
}
const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
......@@ -1071,10 +1074,9 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor =
IS_DGATED ? *(reinterpret_cast<const Tensor *>(grad)) : grad_empty_tensor;
const Tensor gated_input_tensor = *reinterpret_cast<const Tensor *>(gated_input);
Tensor *output_tensor = reinterpret_cast<Tensor *>(output);
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input);
Tensor *output_tensor = convertNVTETensorCheck(output);
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
......
......@@ -904,15 +904,15 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype));
}
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y,
FP8_SHMEM_DIM_X, cols, 0, sizeof(OType));
FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype));
cast_fp8_2D_kernel<IS_DBIAS, IS_DACT, ParamOP, OP, IType, OType>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_act_input, tensor_map_output,
......@@ -1004,24 +1004,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType));
MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(IType));
typeToNumBits(input.dtype()));
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType));
typeToNumBits(output->dtype()));
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows,
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
sizeof(OType));
typeToNumBits(output->dtype()));
}
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
......@@ -1133,7 +1133,7 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16;
const int alignment_requirement = TMA_bytes / typeToSize(t->dtype());
const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
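// Quick arithmetic check: with 16-byte TMA alignment, an 8-bit dtype requires
// cols % ((16 * 8) / 8) == cols % 16 == 0, while a 4-bit dtype requires cols % 32 == 0;
// expressing the requirement in bits is what allows sub-byte dtypes to be handled here.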
......@@ -1254,23 +1254,23 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
const Tensor *activation_input_tensor;
if constexpr (IS_DBIAS || IS_DACT) {
// backward - input is incoming gradient
input_tensor = reinterpret_cast<const Tensor *>(grad);
activation_input_tensor = reinterpret_cast<const Tensor *>(input);
input_tensor = convertNVTETensorCheck(grad);
activation_input_tensor = convertNVTETensor(input);
} else {
// forward - input is activation input
input_tensor = reinterpret_cast<const Tensor *>(input);
input_tensor = convertNVTETensorCheck(input);
activation_input_tensor = nullptr;
}
auto output_tensor = reinterpret_cast<Tensor *>(output);
auto dbias_tensor = reinterpret_cast<Tensor *>(dbias);
auto workspace_tensor = reinterpret_cast<Tensor *>(workspace);
auto output_tensor = convertNVTETensorCheck(output);
auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = convertNVTETensor(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(reinterpret_cast<const Tensor *>(noop)) : Tensor();
const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
......@@ -1315,12 +1315,25 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
FP8BlockwiseRowwiseOption rowwise_option = output_tensor->has_data()
? FP8BlockwiseRowwiseOption::ROWWISE
: FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option =
output_tensor->has_columnwise_data() ? FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE
: FP8BlockwiseColumnwiseOption::NONE;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
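// Summary of the selection above: the rowwise and columnwise options are chosen independently;
// no tensor data -> NONE, data present with the COMPACT block-scale format -> *_COMPACT,
// data present otherwise -> *_GEMM_READY.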
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
......
......@@ -13,16 +13,42 @@ namespace transformer_engine {
namespace cuda_driver {
void *get_symbol(const char *symbol) {
void *entry_point;
#ifndef USE_ROCM
typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int,
unsigned long long, // NOLINT(*)
cudaDriverEntryPointQueryResult *);
typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long, // NOLINT(*)
cudaDriverEntryPointQueryResult *);
#endif
void *get_symbol(const char *symbol, int cuda_version) {
#ifndef USE_ROCM
constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
// We already link against libcudart.so, so we can search for the symbol in the current context
static GetEntryPoint driver_entrypoint_fun =
reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));
cudaDriverEntryPointQueryResult driver_result;
#endif
void *entry_point = nullptr;
#ifdef USE_ROCM
hipDriverProcAddressQueryResult driver_result;
NVTE_CHECK_CUDA(hipGetProcAddress(symbol, &entry_point, HIP_VERSION_MAJOR*100+HIP_VERSION_MINOR, 0, &driver_result));
NVTE_CHECK(driver_result == HIP_GET_PROC_ADDRESS_SUCCESS,
"Could not find CUDA driver entry point for ", symbol);
#else
cudaDriverEntryPointQueryResult driver_result;
NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
if (driver_entrypoint_versioned_fun != nullptr) {
// Found versioned entrypoint function
NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
cudaEnableDefault, &driver_result));
} else {
NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
// Versioned entrypoint function not found
NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
}
NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
"Could not find CUDA driver entry point for ", symbol);
#endif
......
......@@ -19,7 +19,7 @@ namespace transformer_engine {
namespace cuda_driver {
/*! \brief Get pointer corresponding to symbol in CUDA driver library */
void *get_symbol(const char *symbol);
void *get_symbol(const char *symbol, int cuda_version = 12010);
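/* Illustrative lookup (example symbol and version only):
 *   auto fn = reinterpret_cast<CUresult (*)(size_t *, size_t *)>(
 *       get_symbol("cuMemGetInfo", 12010));
 */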
/*! \brief Call function in CUDA driver library
*
......
......@@ -326,9 +326,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
alignas(64) CUtensorMap tensor_map_output{};
create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(IType));
SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, cols, 0, sizeof(OType));
SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype()));
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
<<<grid, block, 0, stream>>>(tensor_map_input, tensor_map_output, scales_ptr,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "multi_stream.h"
#include <transformer_engine/multi_stream.h>
#include <mutex>
#include <vector>
#include "cuda_runtime.h"
#include "logging.h"
namespace transformer_engine::detail {
cudaStream_t get_compute_stream(int idx) {
const size_t num_streams = nvte_get_num_compute_streams();
NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx,
", but there are ", num_streams, " streams)");
static std::vector<cudaStream_t> streams(num_streams);
static std::once_flag stream_init_flag;
auto init = [&]() {
for (size_t i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1));
}
};
std::call_once(stream_init_flag, init);
return streams[idx];
}
cudaEvent_t get_compute_stream_event(int idx) {
const size_t num_streams = nvte_get_num_compute_streams();
NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx,
", but there are ", num_streams, " streams)");
static std::vector<cudaEvent_t> events(num_streams);
static std::once_flag event_init_flag;
auto init = [&]() {
for (size_t i = 0; i < num_streams; i++) {
NVTE_CHECK_CUDA(cudaEventCreate(&events[i]));
}
};
std::call_once(event_init_flag, init);
return events[idx];
}
int get_num_compute_streams() {
static constexpr int num_compute_streams = 4;
return num_compute_streams;
}
} // namespace transformer_engine::detail
int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#define TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
#include <cuda_runtime.h>
namespace transformer_engine::detail {
int get_num_compute_streams();
cudaStream_t get_compute_stream(int idx);
cudaEvent_t get_compute_stream_event(int idx);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
......@@ -155,8 +155,8 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
// Input matrices are divided into tiles
// Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size * 8 / typeToNumBits(type);
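// Example: with desired_load_store_size = 8 bytes (an assumed value for illustration) and a
// 16-bit dtype, tile_dim = 32 * 8 * 8 / 16 = 128, i.e. a 128x128 element tile.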
// Add tensors to kernel argument struct
MultiPaddingArgs kernel_args;
......@@ -211,8 +211,8 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
std::vector<Tensor*> input_list_, output_list_;
std::vector<int> padded_num_rows_list_;
for (size_t i = 0; i < num_tensors; ++i) {
input_list_.push_back(reinterpret_cast<Tensor*>(const_cast<NVTETensor&>(input_list[i])));
output_list_.push_back(reinterpret_cast<Tensor*>(output_list[i]));
input_list_.push_back(convertNVTETensorCheck(input_list[i]));
output_list_.push_back(convertNVTETensorCheck(output_list[i]));
padded_num_rows_list_.push_back(padded_num_rows_list[i]);
}
multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
......
......@@ -80,6 +80,10 @@
.value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
.value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
.value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \
m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
.value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \
......
......@@ -21,6 +21,10 @@ using namespace __hip_internal;
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif
#ifdef __HIP_PLATFORM_AMD__
typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2)));
#else
......
......@@ -12,7 +12,7 @@ from nvdlfw_inspect.registry import Registry
import torch
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import all_tensor_types
from transformer_engine.pytorch.tensor import get_all_tensor_types
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
......@@ -424,7 +424,7 @@ class TransformerEngineAPI(BaseNamespaceAPI):
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
assert type(ret) in all_tensor_types
assert type(ret) in get_all_tensor_types()
if (
type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck
and "dtype" in kwargs
......@@ -438,4 +438,4 @@ class TransformerEngineAPI(BaseNamespaceAPI):
def end_debug(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
TEDebugState.reset()
TEDebugState._reset()
......@@ -49,7 +49,7 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
fp8_dtype = tex.DType.kFloat8E5M2
amax = tensor.abs().max().float()
one = torch.ones(1, device=tensor.device)
scale = _default_sf_compute(amax, one, fp8_max)
scale = _default_sf_compute(amax, one, fp8_max, 0)
quantizer = Float8Quantizer(scale, amax, fp8_dtype)
else:
......
......@@ -120,7 +120,6 @@ class LogFp8TensorStats(BaseLogTensorStats):
if not rowwise:
return # tensor was already seen rowwise in the other gemm
tensor = tensor._data
options = (
config.get("start_step", None),
config.get("end_step", None),
......
......@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.debug.features.api import TEConfigAPIMapper
......@@ -39,7 +40,7 @@ def per_tensor_cast(
}, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE."
tensor = tensor.contiguous()
quantizer = Float8CurrentScalingQuantizer(fp8_dtype)
quantizer = Float8CurrentScalingQuantizer(fp8_dtype, device=tensor.device)
if out is not None:
quantizer.update_quantized(tensor, out)
......@@ -81,7 +82,6 @@ class PerTensorScaling(TEConfigAPIMapper):
transformer_engine:
PerTensorScaling:
enabled: True
margin: 1
gemms: [dgrad]
tensors: [weight, activation]
"""
......@@ -118,7 +118,7 @@ class PerTensorScaling(TEConfigAPIMapper):
if key not in ["gemm", "tensor"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), (
assert isinstance(default_quantizer, Float8Quantizer), (
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: "
"Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast."
f" {layer_name}"
......
......@@ -96,7 +96,10 @@ STATS = {
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
"numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))),
"numel": (
lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda buffers: sum(_get(buffers, "numel")),
),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
"l2_norm_square": (
lambda x: torch.sum(x**2),
......@@ -137,7 +140,7 @@ STATS = {
- min(_get(buffers, "dynamic_range_bottom")),
),
"underflows%": (
lambda x: (x == 0).sum() / x.numel() * 100,
lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100,
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
),
}
......@@ -14,10 +14,11 @@ import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
QuantizedTensorBase,
prepare_for_saving,
restore_from_saved,
)
......@@ -299,8 +300,9 @@ class DebugQuantizer(Quantizer):
iteration=self.iteration,
dtype=dtype,
)
if columnwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
if dtype is not None:
if columnwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
if self.rowwise_tensor_plan == API_CALL_MODIFY:
rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
layer_name=self.layer_name,
......@@ -311,8 +313,9 @@ class DebugQuantizer(Quantizer):
iteration=self.iteration,
dtype=dtype,
)
if rowwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
if dtype is not None:
if rowwise_gemm_tensor.dtype != dtype:
raise ValueError("Dtype does not match the output of the modify_tensor call")
# 3. If some tensors still are not defined we use high precision tensor.
if self.rowwise_tensor_plan == HIGH_PRECISION:
......@@ -332,6 +335,7 @@ class DebugQuantizer(Quantizer):
quantizer=self,
layer_name=self.layer_name,
tensor_name=self.tensor_name,
original_tensor=tensor,
)
def process_gemm_output(self, tensor: torch.Tensor):
......@@ -455,8 +459,12 @@ class DebugQuantizer(Quantizer):
return True
return False
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Probably not needed for debug quantizer"""
return None
class DebugQuantizedTensor:
class DebugQuantizedTensor(QuantizedTensorBase):
"""
Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method
......@@ -470,6 +478,7 @@ class DebugQuantizedTensor:
quantizer,
layer_name=None,
tensor_name=None,
original_tensor=None,
):
self.rowwise_gemm_tensor = rowwise_gemm_tensor
......@@ -477,6 +486,7 @@ class DebugQuantizedTensor:
self.quantizer = quantizer
self._layer_name = layer_name
self._tensor_name = tensor_name
self._original_tensor = original_tensor
def prepare_for_saving(self):
""" " Prepare for saving method override"""
......@@ -524,5 +534,5 @@ class DebugQuantizedTensor:
"""Size of the tensor."""
return self.rowwise_gemm_tensor.size()
def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
"""Update usage of the tensor."""