Commit 800f4dff authored by yuguo
parents 1e018a45 9fe13a33
import time

import torch
import transformer_engine.pytorch as te  # ensures the TE PyTorch extension is initialized
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
    Float8Tensor,
)
from transformer_engine.pytorch.triton.per_token_group_quant import (
    channelwise_dequantize,
    channelwise_dequantize_transA,
    channelwise_dequantize_transB,
    per_token_quant_fp8_to_int8,
    per_token_quant_fp8_to_int8_opt,
    per_token_quant_fp8_to_int8_v2,
    per_token_quant_int8,
    per_token_quant_int8_v2,
)
# TN
m = 4096
k = 4096
n = 4096
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
out_dtype = torch.int32
# Allocate cuBLAS workspace
workspace_size = 128
workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device)
out_quantizer = None
accumulate = False
use_gelu = False
use_bias = False
bias = None
use_grad = False
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
aux_tensor = torch.empty((m, n), dtype=out_dtype, device=device) if use_gelu else None
out = torch.empty((m, n), dtype=out_dtype, device=device) if accumulate else None
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
use_split_accumulator = False
# bf16 to int8
# transa = True
# transb = False
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_out = torch.matmul(x_bf16, w_bf16.t())
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# # print("x_int8: ", x_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # y_int32 = torch._int_mm(x_int8, w_int8.t())
# # print("y_int32: ", y_int32)
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# # print("out_scales.shape: ", out_scales.shape)
# # print("out_scales: ", out_scales)
# # print("bf16_out: ", bf16_out)
# # print("output: ", output)
# # NN
# transa = False
# transb = False
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dx = torch.matmul(dy_bf16, w_bf16)
# dy_int8, dy_scales = per_token_quant_int8(dy_bf16)
# w_int8, w_scales = per_token_quant_int8_v2(w_bf16)
# # print("dy_scales.shape: ", dy_scales.shape)
# # print("w_scales.shape: ", w_scales.shape)
# # print("dy_int8: ", dy_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # dx_int32 = torch._int_mm(dy_int8, w_int8)
# # print("dx_int32: ", dx_int32)
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # print("dx_scales.shape: ", dx_scales.shape)
# # print("dx_scales: ", dx_scales)
# # print("bf16_dx: ", bf16_dx)
# # print("dx: ", dx)
# # NT
# transa = False
# transb = True
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
# dy_int8, dy_scales = per_token_quant_int8_v2(dy_bf16)
# x_int8, x_scales = per_token_quant_int8_v2(x_bf16)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # print("bf16_dw: ", bf16_dw)
# # print("dw: ", dw)
# fp8 to int8
quantizer = Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E5M2,
device="cuda",
force_pow_2_scales=False,
amax_epsilon=0.0,
)
# current scaling
def to_float8_CS(
tensor: torch.Tensor,
fp8_dtype: tex.DType = tex.DType.kFloat8E5M2,
return_transpose: bool = False,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
) -> Float8Tensor:
"""Cast tensor to FP8"""
if return_transpose:
quantizer.set_usage(rowwise=True, columnwise=True)
else:
quantizer.set_usage(rowwise=True, columnwise=False)
return quantizer(tensor)
# TN
transa = True
transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
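# Report the average latency of the 20 timed bf16 TN matmuls (diagnostic only).
print("bf16 TN matmul avg: {:.3f} ms".format((end - start) / 20 * 1e3))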
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 with current scaling
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e5m2))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e5m2))
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
print("bf16_out: ", bf16_out)
print("output: ", output)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
torch.cuda.synchronize()
end = time.time()
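# Report the average latency of the 20 timed quantize + INT8 GEMM + dequantize iterations (TN).
print("int8 TN pipeline avg: {:.3f} ms".format((end - start) / 20 * 1e3))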
# NN
# transa = True
transa = False
transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
end = time.time()
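# Report the average latency of the 20 timed bf16 NN matmuls (diagnostic only).
print("bf16 NN matmul avg: {:.3f} ms".format((end - start) / 20 * 1e3))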
# Cast to FP8 with current scaling
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("dy_scales.shape: ", dy_scales.shape)
# print("w_scales.shape: ", w_scales.shape)
# print("dy_int8: ", dy_int8)
# print("w_int8: ", w_int8)
# print("w_scales: ", w_scales)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dx_int32 = tex.generic_gemm(
w_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
print("bf16_dx: ", bf16_dx)
print("dx: ", dx)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
dx_int32 = tex.generic_gemm(
w_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
torch.cuda.synchronize()
end = time.time()
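# Report the average latency of the 20 timed quantize + INT8 GEMM + dequantize iterations (NN).
print("int8 NN pipeline avg: {:.3f} ms".format((end - start) / 20 * 1e3))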
# NT
# transa = True
# transb = False
transa = False
transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
end = time.time()
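# Report the average latency of the 20 timed bf16 NT matmuls (diagnostic only).
print("bf16 NT matmul avg: {:.3f} ms".format((end - start) / 20 * 1e3))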
# Cast to FP8 with current scaling
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
x_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("bf16_dw: ", bf16_dw)
print("dw: ", dw)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dw_int32 = tex.generic_gemm(
x_int8,
transa,
dy_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
torch.cuda.synchronize()
end = time.time()
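# Report the average latency of the 20 timed quantize + INT8 GEMM + dequantize iterations (NT).
print("int8 NT pipeline avg: {:.3f} ms".format((end - start) / 20 * 1e3))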
......@@ -678,10 +678,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTE_ERROR("TT layout not allowed.");
}
const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
is_int8_dtype(inputB->data.dtype);
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
NVTE_CHECK(!use_int8, "INT8 GEMM only supports plain INT8 GEMM without any epilogue.");
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa, transb, grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0,
false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
......
......@@ -9,6 +9,22 @@
namespace {
template<typename T>
void printTensor(const std::string& str, const T* devTensor, size_t size) {
T* hostTensor = (T*)malloc(size * sizeof(T));
hipMemcpy(hostTensor, devTensor, size * sizeof(T), hipMemcpyDeviceToHost);
std::cout << str << ": ";
for (size_t i = 0; i < size; i++) {
if (i % 16 == 0) {
std::cout << std::endl;
}
std::cout << static_cast<float>(hostTensor[i]) << ", ";
}
std::cout << str << ": finish" << std::endl;
free(hostTensor);
}
hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
......@@ -17,7 +33,11 @@ hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
case DType::kFloat32:
return HIPBLAS_R_32F;
case DType::kBFloat16:
return HIPBLAS_R_16B;
case DType::kInt8:
return HIPBLAS_R_8I;
case DType::kInt32:
return HIPBLAS_R_32I;
default:
NVTE_ERROR("Invalid type");
}
......@@ -82,6 +102,16 @@ void hipblas_gemm(const Tensor *inputA,
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
int int_one = 1;
int int_zero = 0;
int int_beta = int_zero;
bool use_int8 = false;
if ((A_type == HIPBLAS_R_8I) && (B_type == HIPBLAS_R_8I) && (D_type == HIPBLAS_R_32I)) {
NVTE_CHECK(!accumulate, "INT8 GEMM does not support accumulation.");
use_int8 = true;
computeType = HIPBLAS_R_32I;
}
hipblasSetStream(handle, stream);
// execute multiply
......@@ -92,20 +122,20 @@ void hipblas_gemm(const Tensor *inputA,
m,
n,
k,
static_cast<const void*>(&one),
use_int8 ? static_cast<const void*>(&int_one) : static_cast<const void*>(&one),
A,
A_type,
lda,
B,
B_type,
ldb,
static_cast<const void*>(&beta),
use_int8 ? static_cast<const void*>(&int_beta) : static_cast<const void*>(&beta),
D,
D_type,
ldd,
computeType,
HIPBLAS_GEMM_DEFAULT);
// printTensor<int32_t>("D_tensor: ", reinterpret_cast<int32_t*>(D), 10);
if (status != HIPBLAS_STATUS_SUCCESS) {
NVTE_ERROR("hipblasGemmEx execution failed");
}
......
......@@ -84,6 +84,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
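/*! \brief Compute per-column (channel-wise) absolute maxima of a 2D tensor.
 *
 *  The input is rescaled by the given FP8 scale factor before the reduction, and each
 *  column's amax is divided by 127 so the output can be used directly as per-channel
 *  INT8 quantization scales.
 */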
void nvte_compute_channel_colwise_amax(const NVTETensor input, NVTETensor output, const NVTETensor fp8_scale, cudaStream_t stream);
/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
......
......@@ -416,6 +416,15 @@ enum class DType {
kNumTypes
};
/*! \brief Check if TE datatype is INT8
*
* Return true if TE datatype is INT8
* \param[in] t TE datatype of interest
*/
inline bool is_int8_dtype(const DType t) {
return t == DType::kInt8;
}
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
......
......@@ -16,7 +16,10 @@
#include "recipe_common.cuh"
#ifdef __HIP_PLATFORM_AMD__
#include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
constexpr int kColwiseReduceTileSize = 32;
constexpr int THREADS_PER_BLOCK = 1024;
#endif
namespace transformer_engine {
......@@ -61,6 +64,65 @@ __launch_bounds__(amax_kernel_threads) __global__
}
}
template <typename T>
__inline__ __device__ T WarpReduceMax(T val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) {
val = fmaxf(__shfl_down(val, offset), val);
}
return val;
}
template <int nvec, typename InputType>
__launch_bounds__(1024) __global__
void channel_colwise_amax_kernel(float *dst, const InputType *src, const float *fp8_scale, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x;
float channel_amax = 0.f;
float scale = fp8_scale[0];
if (j < N) {
for (int i = threadIdx.y; i < M; i += blockDim.y) {
channel_amax = fmaxf(fabsf(static_cast<float>(src[i * N + j]) * scale), channel_amax);
}
}
g_shared[threadIdx.y][threadIdx.x] = channel_amax;
__syncthreads();
float amax = g_shared[threadIdx.x][threadIdx.y];
amax = WarpReduceMax<float>(amax, kColwiseReduceTileSize / 2);
if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
dst[j] = static_cast<float>(amax) / 127.0; // scales
}
}
}
template <typename InputType>
__launch_bounds__(THREADS_PER_BLOCK) __global__
void channel_colwise_amax_kernel_v2(const InputType* in, float* out, const float* fp8_scale, int m, int n) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
float scale = fp8_scale[0];
int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int col_idx = idx / THREADS_PER_COL;
int row_idx = idx % THREADS_PER_COL;
float thread_data = 0.f;  // initialize so threads past the end of the column do not feed garbage into the reduction
if (row_idx < m)
thread_data = fabsf((float)in[row_idx * n + col_idx] * scale);
float local_amax;
if (row_idx < (BLOCKS_PER_COL-1) * THREADS_PER_BLOCK) {
local_amax = BlockReduce(block_temp_storage).Reduce(thread_data, hipcub::Max());
} else {
local_amax = BlockReduce(block_temp_storage).Reduce(thread_data, hipcub::Max(), m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
}
if (threadIdx.x == 0) {
atomicMax(&out[col_idx], local_amax);
out[col_idx] = out[col_idx] / 127.0;
}
}
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
......@@ -103,6 +165,26 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <int nvec, typename InputType>
void launch_channel_colwise_amax_kernel(const InputType *input, float *amax, const float *fp8_scale, const size_t M, const size_t N, cudaStream_t stream) {
// Zero out the per-column output before the reduction
cudaMemsetAsync(amax, 0, N * sizeof(float), stream);
// Launch kernel
int B = (N - 1) / kColwiseReduceTileSize + 1;
channel_colwise_amax_kernel<nvec, InputType><<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(amax, input, fp8_scale, M, N);
// Launch kernel v2
// dim3 block, grid;
// int BLOCKS_PER_COL = ceil(float(M) / THREADS_PER_BLOCK);
// block.x = THREADS_PER_BLOCK;
// grid.x = BLOCKS_PER_COL * N;
// hipLaunchKernelGGL((channel_colwise_amax_kernel_v2<InputType>), dim3(grid), dim3(block), 0, stream, input, amax, fp8_scale, M, N);
// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace
} // namespace transformer_engine
......@@ -150,6 +232,40 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
stream);); // NOLINT(*)
}
void nvte_compute_channel_colwise_amax(const NVTETensor input_, const NVTETensor output_, const NVTETensor fp8_scale_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_channel_colwise_amax);
using namespace transformer_engine;
// Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
NVTE_CHECK(fp8_scale_ != nullptr, "Invalid fp8 scale tensor (got NULL)");
const auto &input = *convertNVTETensorCheck(input_);
const auto &fp8_scale = *convertNVTETensorCheck(fp8_scale_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode=",
to_string(input.scaling_mode));
NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
CheckOutputTensor(output, "output_compute_amax", true);
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_channel_colwise_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.data.dptr), reinterpret_cast<const float *>(fp8_scale.data.dptr),
input.data.shape[0],
input.data.shape[1],
stream);); // NOLINT(*)
}
namespace transformer_engine {
namespace {
......
......@@ -253,6 +253,7 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
case transformer_engine::DType::kInt8:
return 8;
case transformer_engine::DType::kFloat4E2M1:
return 4;
......@@ -263,6 +264,8 @@ inline size_t typeToNumBits(transformer_engine::DType t) {
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt8:
return torch::kInt8;
case transformer_engine::DType::kInt16:
return torch::kInt16;
case transformer_engine::DType::kInt32:
......@@ -308,6 +311,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kInt32;
case torch::kInt64:
return transformer_engine::DType::kInt64;
case torch::kInt8:
return transformer_engine::DType::kInt8;
default:
std::cout << "Type: " << static_cast<int>(t) << std::endl;
NVTE_ERROR("Invalid type");
......
......@@ -259,6 +259,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
void compute_amax(const at::Tensor &tensor, at::Tensor &amax);
void compute_channel_colwise_amax(const at::Tensor &tensor, at::Tensor &amax, at::Tensor &fp8_scale);
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
......
......@@ -216,6 +216,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
"Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
py::call_guard<py::gil_scoped_release>());
m.def("compute_channel_colwise_amax", &transformer_engine::pytorch::compute_channel_colwise_amax,
"Compute colwise absolute max value in tensor", py::arg("input"), py::arg("amax"), py::arg("fp8_scale"),
py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction",
&transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
......
......@@ -28,6 +28,19 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
}
void compute_channel_colwise_amax(const at::Tensor& tensor, at::Tensor& amax, at::Tensor& fp8_scale) {
auto input_tensor = tensor.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& te_fp8_scale = makeTransformerEngineTensor(fp8_scale);
TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
TORCH_CHECK(fp8_scale.numel() == 1, "fp8_scale must have exactly one element");
TORCH_CHECK(te_input.shape().ndim == 2, "input ndim must be 2");
const TensorWrapper& te_amax = makeTransformerEngineTensor(amax);
nvte_compute_channel_colwise_amax(te_input.data(), te_amax.data(), te_fp8_scale.data(), at::cuda::getCurrentCUDAStream());
}
void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
......
......@@ -10,7 +10,326 @@ import pandas as pd
import logging
import math
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import transformer_engine_torch as tex
from triton.language.extra import libdevice
@triton.jit
def _per_token_quant_int8(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
BLOCK: tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_int8(x):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_int8[(M, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_quant_int8_v2(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
M,
BLOCK: tl.constexpr,
):
col_id = tl.program_id(0)
rows = tl.arange(0, BLOCK)
mask = rows < M
x = tl.load(x_ptr + rows * stride_x + col_id, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + rows * stride_xq + col_id, x_q, mask=mask)
tl.store(scale_ptr + col_id, scale_x)
def per_token_quant_int8_v2(x):
assert x.is_contiguous()
x = x.view(-1, x.shape[-1])
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty((1, x.shape[-1]),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(M)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
_per_token_quant_int8_v2[(N, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
M=M,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_quant_fp8_to_int8(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
fp8_scale_inv,
BLOCK: tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
fp8_scale = tl.load(fp8_scale_inv)
x = x * fp8_scale
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
def per_token_quant_fp8_to_int8(x, fp8_scale_inv, inplace=False):
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if inplace:
x_q = x.view(dtype=torch.int8)
else:
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
_per_token_quant_fp8_to_int8[(M, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
fp8_scale_inv=fp8_scale_inv,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_quant_fp8_to_int8_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
fp8_scale_inv,
BLOCK: tl.constexpr,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
scales = tl.load(scale_ptr + cols, mask=mask,
other=0.0).to(tl.float32)
fp8_scale = tl.load(fp8_scale_inv)
x = x * fp8_scale
x_q = x / scales
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
def per_token_quant_fp8_to_int8_opt(x, fp8_scale_inv, inplace=False):
assert x.is_contiguous()
x = x.view(-1, x.shape[-1])
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if inplace:
x_q = x.view(dtype=torch.int8)
else:
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty((1, x.shape[-1]),
device=x.device,
dtype=torch.float32)
tex.compute_channel_colwise_amax(x, scales, fp8_scale_inv)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
_per_token_quant_fp8_to_int8_opt[(M, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
BLOCK=BLOCK,
fp8_scale_inv=fp8_scale_inv,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@triton.jit
def _per_token_quant_fp8_to_int8_v2(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
M,
fp8_scale_inv,
BLOCK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
col_id = tl.program_id(0)
rows = tl.arange(0, BLOCK)
cols = tl.arange(0, BLOCK_N)
offset_cols = col_id * BLOCK_N + cols
mask = rows[:,None] < M
x = tl.load(x_ptr + rows[:,None] * stride_x + offset_cols[None,:], mask=mask, other=0.0).to(tl.float32)
fp8_scale = tl.load(fp8_scale_inv)
x = x * fp8_scale
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + rows[:,None] * stride_xq + offset_cols[None,:], x_q, mask=mask)
tl.store(scale_ptr + offset_cols, scale_x)
def per_token_quant_fp8_to_int8_v2(x, fp8_scale_inv, inplace=False):
assert x.is_contiguous()
x = x.view(-1, x.shape[-1])
M = x.numel() // x.shape[-1]
N = x.shape[-1]
if inplace:
x_q = x.view(dtype=torch.int8)
else:
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty((1, x.shape[-1]),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(M)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
_per_token_quant_fp8_to_int8_v2[(N//32, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
M=M,
fp8_scale_inv=fp8_scale_inv,
BLOCK=BLOCK,
BLOCK_N=triton.next_power_of_2(32),
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize(A, B, C):
# A: row-wise scales with shape (M, 1); B: column-wise scales with shape (1, N);
# C: int32 GEMM output with shape (M, N).
out_scales = A * B
return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA(A, B, C):
# A: column-wise scales with shape (1, M), transposed so the product with
# B (shape (1, N)) broadcasts over the (M, N) int32 output C.
out_scales = A.T * B
return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA_float(A, B, C):
# Same as channelwise_dequantize_transA, but keeps the result in float32.
out_scales = A.T * B
return out_scales * C.to(dtype=torch.float32)
@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transB(A, B, C):
# A: row-wise scales with shape (M, 1); B: row-wise scales with shape (N, 1),
# transposed to broadcast across the columns of the (M, N) output.
out_scales = A * B.T
return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)
def to_int8(tensor: torch.Tensor):
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
......