Commit 9fe13a33 authored by yuguo

[DCU] Preliminary support for channelwise

parent 40a4d896
import time

import torch

import transformer_engine.pytorch as te  # noqa: F401  (initializes TE PyTorch extensions)
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
    Float8Tensor,
)
from transformer_engine.pytorch.triton.per_token_group_quant import (
    channelwise_dequantize,
    channelwise_dequantize_transA,
    channelwise_dequantize_transB,
    per_token_quant_fp8_to_int8,
    per_token_quant_fp8_to_int8_opt,
    per_token_quant_fp8_to_int8_v2,
    per_token_quant_int8,
    per_token_quant_int8_v2,
)
import transformer_engine_torch as tex
# TN
m = 4096
k = 4096
n = 4096
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
out_dtype = torch.int32
# Allocate GEMM workspace
workspace_size = 128
workspace = torch.empty(workspace_size, dtype=torch.uint8, device=device)
out_quantizer = None
accumulate = False
use_gelu = False
use_bias = False
bias = None
use_grad = False
assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
aux_tensor = torch.empty((m, n), dtype=out_dtype, device=device) if use_gelu else None
out = torch.empty((m, n), dtype=out_dtype, device=device) if accumulate else None
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
use_split_accumulator = False
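# The benchmark below covers three GEMM layouts, each first in bf16 as a reference and
# then as an int8 GEMM whose inputs are re-quantized from FP8 (current-scaling) data
# with per-token / per-channel scales:
#   TN: y  = x @ w.T   (forward)
#   NN: dx = dy @ w    (dgrad)
#   NT: dw = dy.T @ x  (wgrad)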
# bf16 to int8
# transa = True
# transb = False
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_out = torch.matmul(x_bf16, w_bf16.t())
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# # print("x_int8: ", x_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# y_int32 = tex.generic_gemm(
# w_int8,
# transa,
# x_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # y_int32 = torch._int_mm(x_int8, w_int8.t())
# # print("y_int32: ", y_int32)
# output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# # print("out_scales.shape: ", out_scales.shape)
# # print("out_scales: ", out_scales)
# # print("bf16_out: ", bf16_out)
# # print("output: ", output)
# # NN
# transa = False
# transb = False
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dx = torch.matmul(dy_bf16, w_bf16)
# dy_int8, dy_scales = per_token_quant_int8(dy_bf16)
# w_int8, w_scales = per_token_quant_int8_v2(w_bf16)
# # print("dy_scales.shape: ", dy_scales.shape)
# # print("w_scales.shape: ", w_scales.shape)
# # print("dy_int8: ", dy_int8)
# # print("w_int8: ", w_int8)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dx_int32 = tex.generic_gemm(
# w_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# # dx_int32 = torch._int_mm(dy_int8, w_int8)
# # print("dx_int32: ", dx_int32)
# dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# # print("dx_scales.shape: ", dx_scales.shape)
# # print("dx_scales: ", dx_scales)
# # print("bf16_dx: ", bf16_dx)
# # print("dx: ", dx)
# # NT
# transa = False
# transb = True
# dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
# x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
# bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
# dy_int8, dy_scales = per_token_quant_int8_v2(dy_bf16)
# x_int8, x_scales = per_token_quant_int8_v2(x_bf16)
# # cuBLAS GEMM
# # return type is out, bias_grad, gelu_input, extra_output
# # We are just capturing out.
# dw_int32 = tex.generic_gemm(
# x_int8,
# transa,
# dy_int8,
# transb,
# out,
# out_quantizer,
# TE_DType[out_dtype],
# bias,
# bias_dtype,
# use_gelu,
# aux_tensor,
# use_grad,
# workspace,
# workspace.shape[0],
# accumulate,
# use_split_accumulator,
# )[0]
# dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# # print("bf16_dw: ", bf16_dw)
# # print("dw: ", dw)
# fp8 to int8
quantizer = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E5M2,
    device="cuda",
    force_pow_2_scales=False,
    amax_epsilon=0.0,
)
# current scaling
def to_float8_CS(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E5M2,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> Float8Tensor:
    """Cast a tensor to FP8 using current (per-tensor) scaling.

    The module-level quantizer above is reused, so the fp8_dtype,
    force_pow_2_scales and amax_epsilon arguments are currently ignored.
    """
    if return_transpose:
        quantizer.set_usage(rowwise=True, columnwise=True)
    else:
        quantizer.set_usage(rowwise=True, columnwise=False)
    return quantizer(tensor)
# TN
transa = True
transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
print("TN bf16 matmul avg (ms): ", (end - start) / 20 * 1e3)
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e5m2))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e5m2))
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
    w_int8,
    transa,
    x_int8,
    transb,
    out,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
print("bf16_out: ", bf16_out)
print("output: ", output)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
    y_int32 = tex.generic_gemm(
        w_int8,
        transa,
        x_int8,
        transb,
        out,
        out_quantizer,
        TE_DType[out_dtype],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )[0]
    output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
torch.cuda.synchronize()
end = time.time()
print("TN int8 path avg (ms): ", (end - start) / 20 * 1e3)
# NN
# transa = True
transa = False
transb = False
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_dx = torch.matmul(dy_bf16, w_bf16)
torch.cuda.synchronize()
end = time.time()
print("NN bf16 matmul avg (ms): ", (end - start) / 20 * 1e3)
# Cast to FP8 and back
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
# print("dy_scales.shape: ", dy_scales.shape)
# print("w_scales.shape: ", w_scales.shape)
# print("dy_int8: ", dy_int8)
# print("w_int8: ", w_int8)
# print("w_scales: ", w_scales)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dx_int32 = tex.generic_gemm(
    w_int8,
    transa,
    dy_int8,
    transb,
    out,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# dx_int32 = torch._int_mm(dy_int8, w_int8)
# print("dx_int32: ", dx_int32)
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
print("bf16_dx: ", bf16_dx)
print("dx: ", dx)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    # w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
    dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    w_int8, w_scales = per_token_quant_fp8_to_int8_opt(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
    # w_int8, w_scales = per_token_quant_fp8_to_int8_v2(w_fp8._data.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
    # w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._transpose.view(dtype=torch.float8_e5m2), w_fp8._scale_inv, False)
    dx_int32 = tex.generic_gemm(
        w_int8,
        transa,
        dy_int8,
        transb,
        out,
        out_quantizer,
        TE_DType[out_dtype],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )[0]
    dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
    # dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
torch.cuda.synchronize()
end = time.time()
print("NN int8 path avg (ms): ", (end - start) / 20 * 1e3)
# NT
# transa = True
# transb = False
transa = False
transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
torch.cuda.synchronize()
end = time.time()
print("NT bf16 matmul avg (ms): ", (end - start) / 20 * 1e3)
# Cast to FP8 and back
# dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
# x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
# x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
    x_int8,
    transa,
    dy_int8,
    transb,
    out,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("bf16_dw: ", bf16_dw)
print("dw: ", dw)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
    # dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
    # x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2, return_transpose=True)
    dy_fp8 = to_float8_CS(dy_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E5M2)
    # dy_int8, dy_scales = per_token_quant_fp8_to_int8_v2(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    # x_int8, x_scales = per_token_quant_fp8_to_int8_v2(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(dy_fp8._data.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    x_int8, x_scales = per_token_quant_fp8_to_int8_opt(x_fp8._data.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    # dy_int8, dy_scales = per_token_quant_fp8_to_int8(dy_fp8._transpose.view(dtype=torch.float8_e5m2), dy_fp8._scale_inv, False)
    # x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._transpose.view(dtype=torch.float8_e5m2), x_fp8._scale_inv, False)
    dw_int32 = tex.generic_gemm(
        x_int8,
        transa,
        dy_int8,
        transb,
        out,
        out_quantizer,
        TE_DType[out_dtype],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )[0]
    dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
    # dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
torch.cuda.synchronize()
end = time.time()
print("NT int8 path avg (ms): ", (end - start) / 20 * 1e3)
\ No newline at end of file
......@@ -718,10 +718,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTE_ERROR("TT layout not allowed.");
}
const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
is_int8_dtype(inputB->data.dtype);
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
NVTE_CHECK(!use_int8, "Int8 GEMM only supports a plain int8 GEMM without any epilogue.");
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa, transb, grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0,
false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
......
......@@ -9,6 +9,22 @@
namespace {
template<typename T>
void printTensor(const std::string& str, const T* devTensor, size_t size) {
T* hostTensor;
hostTensor = (T*)malloc(size * sizeof(T));
hipMemcpy(hostTensor, devTensor, size * sizeof(T), hipMemcpyDeviceToHost);
std::cout << str << ": ";
for (size_t i = 0; i < size; i++) {
if (i % 16 == 0) {
std::cout << std::endl;
}
std::cout << static_cast<float>(hostTensor[i]) << ", ";
}
std::cout << str << ": finish" << std::endl;
free(hostTensor);
}
hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
using namespace transformer_engine;
switch (t) {
......@@ -18,6 +34,10 @@ hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
return HIPBLAS_R_32F;
case DType::kBFloat16:
return HIPBLAS_R_16B;
case DType::kInt8:
return HIPBLAS_R_8I;
case DType::kInt32:
return HIPBLAS_R_32I;
default:
NVTE_ERROR("Invalid type");
}
......@@ -82,6 +102,16 @@ void hipblas_gemm(const Tensor *inputA,
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
int int_one = 1;
int int_zero = 0;
int int_beta = int_zero;
bool use_int8 = false;
if ((A_type == HIPBLAS_R_8I) && (B_type == HIPBLAS_R_8I) && (D_type == HIPBLAS_R_32I)) {
NVTE_CHECK(!accumulate, "Int8 GEMM does not support accumulate.");
use_int8 = true;
computeType = HIPBLAS_R_32I;
}
hipblasSetStream(handle, stream);
// execute multiply
......@@ -92,20 +122,20 @@ void hipblas_gemm(const Tensor *inputA,
m,
n,
k,
static_cast<const void*>(&one),
use_int8 ? static_cast<const void*>(&int_one) : static_cast<const void*>(&one),
A,
A_type,
lda,
B,
B_type,
ldb,
static_cast<const void*>(&beta),
use_int8 ? static_cast<const void*>(&int_beta) : static_cast<const void*>(&beta),
D,
D_type,
ldd,
computeType,
HIPBLAS_GEMM_DEFAULT);
// printTensor<int32_t>("D_tensor: ", reinterpret_cast<int32_t*>(D), 10);
if (status != HIPBLAS_STATUS_SUCCESS) {
NVTE_ERROR("hipblasGemmEx execution failed");
}
......
......@@ -84,6 +84,8 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_compute_channel_colwise_amax(const NVTETensor input, NVTETensor output, const NVTETensor fp8_scale, cudaStream_t stream);
/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
......
......@@ -387,6 +387,15 @@ enum class DType {
kNumTypes
};
/*! \brief Check if TE datatype is INT8
*
* Return true if TE datatype is INT8
* \param[in] t TE DType of interest
*/
inline bool is_int8_dtype(const DType t) {
return t == DType::kInt8;
}
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
......
......@@ -16,7 +16,10 @@
#include "recipe_common.cuh"
#ifdef __HIP_PLATFORM_AMD__
#include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
constexpr int kColwiseReduceTileSize = 32;
constexpr int THREADS_PER_BLOCK = 1024;
#endif
namespace transformer_engine {
......@@ -61,6 +64,65 @@ __launch_bounds__(amax_kernel_threads) __global__
}
}
template <typename T>
__inline__ __device__ T WarpReduceMax(T val, int max = 32) {
for (int offset = max; offset > 0; offset >>= 1) {
val = fmaxf(__shfl_down(val, offset), val);
}
return val;
}
template <int nvec, typename InputType>
__launch_bounds__(1024) __global__
void channel_colwise_amax_kernel(float *dst, const InputType *src, const float *fp8_scale, int M, int N) {
__shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
const int j = blockIdx.x * blockDim.x + threadIdx.x;
float channel_amax = 0.f;
float scale = fp8_scale[0];
if (j < N) {
for (int i = threadIdx.y; i < M; i += blockDim.y) {
channel_amax = fmaxf(fabsf(static_cast<float>(src[i * N + j]) * scale), channel_amax);
}
}
g_shared[threadIdx.y][threadIdx.x] = channel_amax;
__syncthreads();
float amax = g_shared[threadIdx.x][threadIdx.y];
amax = WarpReduceMax<float>(amax, kColwiseReduceTileSize / 2);
if (threadIdx.x == 0) {
const int j = blockIdx.x * blockDim.x + threadIdx.y;
if (j < N) {
dst[j] = static_cast<float>(amax) / 127.0; // scales
}
}
}
template <typename InputType>
__launch_bounds__(THREADS_PER_BLOCK) __global__
void channel_colwise_amax_kernel_v2(const InputType* in, float* out, const float* fp8_scale, int m, int n) {
typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage block_temp_storage;
float scale = fp8_scale[0];
int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
int col_idx = idx / THREADS_PER_COL;
int row_idx = idx % THREADS_PER_COL;
float thread_data = 0.f;
if (row_idx < m)
thread_data = fabsf((float)in[row_idx * n + col_idx] * scale);
float local_amax;
if (row_idx < (BLOCKS_PER_COL-1) * THREADS_PER_BLOCK) {
local_amax = BlockReduce(block_temp_storage).Reduce(thread_data, hipcub::Max());
} else {
local_amax = BlockReduce(block_temp_storage).Reduce(thread_data, hipcub::Max(), m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
}
if (threadIdx.x == 0) {
atomicMax(&out[col_idx], local_amax);
out[col_idx] = out[col_idx] / 127.0;
}
}
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
......@@ -103,6 +165,26 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <int nvec, typename InputType>
void launch_channel_colwise_amax_kernel(const InputType *input, float *amax, const float *fp8_scale, const size_t M, const size_t N, cudaStream_t stream) {
// Zero out amax so we can update with atomic max
cudaMemsetAsync(amax, 0, N * sizeof(float), stream);
// Launch kernel
int B =(N - 1) / kColwiseReduceTileSize + 1;
channel_colwise_amax_kernel<nvec, InputType><<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(amax, input, fp8_scale, M, N);
// Launch kernel v2
// dim3 block, grid;
// int BLOCKS_PER_COL = ceil(float(M) / THREADS_PER_BLOCK);
// block.x = THREADS_PER_BLOCK;
// grid.x = BLOCKS_PER_COL * N;
// hipLaunchKernelGGL((channel_colwise_amax_kernel_v2<InputType>), dim3(grid), dim3(block), 0, stream, input, amax, fp8_scale, M, N);
// Check results
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace
} // namespace transformer_engine
......@@ -150,6 +232,40 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
stream);); // NOLINT(*)
}
void nvte_compute_channel_colwise_amax(const NVTETensor input_, const NVTETensor output_, const NVTETensor fp8_scale_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_channel_colwise_amax);
using namespace transformer_engine;
// Check input tensor
NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
NVTE_CHECK(fp8_scale_ != nullptr, "Invalid fp8 scale tensor (got NULL)");
const auto &input = *convertNVTETensorCheck(input_);
const auto &fp8_scale = *convertNVTETensorCheck(fp8_scale_);
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor for amax computation must unquantized, "
"but got scaling_mode=",
to_string(input.scaling_mode));
NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
CheckOutputTensor(output, "output_compute_amax", true);
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_channel_colwise_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.data.dptr), reinterpret_cast<const float *>(fp8_scale.data.dptr),
input.data.shape[0],
input.data.shape[1],
stream);); // NOLINT(*)
}
namespace transformer_engine {
namespace {
......
......@@ -250,6 +250,7 @@ inline size_t typeToSize(transformer_engine::DType t) {
case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2:
case transformer_engine::DType::kInt8:
return 1;
default:
NVTE_ERROR("Invalid type");
......@@ -258,6 +259,8 @@ inline size_t typeToSize(transformer_engine::DType t) {
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
switch (t) {
case transformer_engine::DType::kInt8:
return torch::kInt8;
case transformer_engine::DType::kInt16:
return torch::kInt16;
case transformer_engine::DType::kInt32:
......@@ -303,6 +306,8 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
return transformer_engine::DType::kInt32;
case torch::kInt64:
return transformer_engine::DType::kInt64;
case torch::kInt8:
return transformer_engine::DType::kInt8;
default:
std::cout << "Type: " << static_cast<int>(t) << std::endl;
NVTE_ERROR("Invalid type");
......
......@@ -261,6 +261,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
void compute_amax(const at::Tensor &tensor, at::Tensor &amax);
void compute_channel_colwise_amax(const at::Tensor &tensor, at::Tensor &amax, at::Tensor &fp8_scale);
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
......
......@@ -216,6 +216,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
"Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
py::call_guard<py::gil_scoped_release>());
m.def("compute_channel_colwise_amax", &transformer_engine::pytorch::compute_channel_colwise_amax,
"Compute colwise absolute max value in tensor", py::arg("input"), py::arg("amax"), py::arg("fp8_scale"),
py::call_guard<py::gil_scoped_release>());
m.def("fused_amax_and_scale_update_after_reduction",
&transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
"Update amax history and FP8 scale/scale_inv after reduction",
......
......@@ -28,6 +28,19 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
}
void compute_channel_colwise_amax(const at::Tensor& tensor, at::Tensor& amax, at::Tensor& fp8_scale) {
auto input_tensor = tensor.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& te_fp8_scale = makeTransformerEngineTensor(fp8_scale);
TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
TORCH_CHECK(fp8_scale.numel() == 1, "fp8_scale must have exactly one element");
TORCH_CHECK(te_input.shape().ndim == 2, "input ndim must be 2");
const TensorWrapper& te_amax = makeTransformerEngineTensor(amax);
nvte_compute_channel_colwise_amax(te_input.data(), te_amax.data(), te_fp8_scale.data(), at::cuda::getCurrentCUDAStream());
}
void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
std::vector<at::Tensor> amax_histories,
std::vector<at::Tensor> scales,
......
......@@ -10,7 +10,326 @@ import pandas as pd
import logging
import math
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import transformer_engine_torch as tex
from triton.language.extra import libdevice
@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


def per_token_quant_int8(x):
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    assert x.is_contiguous()
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, scales

@triton.jit
def _per_token_quant_int8_v2(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    M,
    BLOCK: tl.constexpr,
):
    col_id = tl.program_id(0)
    rows = tl.arange(0, BLOCK)
    mask = rows < M
    x = tl.load(x_ptr + rows * stride_x + col_id, mask=mask, other=0.0).to(tl.float32)
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + rows * stride_xq + col_id, x_q, mask=mask)
    tl.store(scale_ptr + col_id, scale_x)


def per_token_quant_int8_v2(x):
    assert x.is_contiguous()
    x = x.view(-1, x.shape[-1])
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty((1, x.shape[-1]), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(M)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    _per_token_quant_int8_v2[(N,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        M=M,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, scales

@triton.jit
def _per_token_quant_fp8_to_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    fp8_scale_inv,
    BLOCK: tl.constexpr,
):
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    fp8_scale = tl.load(fp8_scale_inv)
    x = x * fp8_scale
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


def per_token_quant_fp8_to_int8(x, fp8_scale_inv, inplace=False):
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    if inplace:
        x_q = x.view(dtype=torch.int8)
    else:
        x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    assert x.is_contiguous()
    _per_token_quant_fp8_to_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        fp8_scale_inv=fp8_scale_inv,
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, scales

@triton.jit
def _per_token_quant_fp8_to_int8_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    fp8_scale_inv,
    BLOCK: tl.constexpr,
):
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    scales = tl.load(scale_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    fp8_scale = tl.load(fp8_scale_inv)
    x = x * fp8_scale
    x_q = x / scales
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)


def per_token_quant_fp8_to_int8_opt(x, fp8_scale_inv, inplace=False):
    assert x.is_contiguous()
    x = x.view(-1, x.shape[-1])
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    if inplace:
        x_q = x.view(dtype=torch.int8)
    else:
        x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty((1, x.shape[-1]), device=x.device, dtype=torch.float32)
    # Per-column scales are computed up front by the channel-wise amax kernel.
    tex.compute_channel_colwise_amax(x, scales, fp8_scale_inv)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    _per_token_quant_fp8_to_int8_opt[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        fp8_scale_inv=fp8_scale_inv,
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, scales

@triton.jit
def _per_token_quant_fp8_to_int8_v2(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    M,
    fp8_scale_inv,
    BLOCK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    col_id = tl.program_id(0)
    rows = tl.arange(0, BLOCK)
    cols = tl.arange(0, BLOCK_N)
    offset_cols = col_id * BLOCK_N + cols
    mask = rows[:, None] < M
    x = tl.load(x_ptr + rows[:, None] * stride_x + offset_cols[None, :], mask=mask, other=0.0).to(tl.float32)
    fp8_scale = tl.load(fp8_scale_inv)
    x = x * fp8_scale
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = x * (127 / absmax)
    x_q = libdevice.nearbyint(x_q).to(tl.int8)
    tl.store(xq_ptr + rows[:, None] * stride_xq + offset_cols[None, :], x_q, mask=mask)
    tl.store(scale_ptr + offset_cols, scale_x)


def per_token_quant_fp8_to_int8_v2(x, fp8_scale_inv, inplace=False):
    assert x.is_contiguous()
    x = x.view(-1, x.shape[-1])
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    if inplace:
        x_q = x.view(dtype=torch.int8)
    else:
        x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty((1, x.shape[-1]), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(M)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    _per_token_quant_fp8_to_int8_v2[(N // 32,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        M=M,
        fp8_scale_inv=fp8_scale_inv,
        BLOCK=BLOCK,
        BLOCK_N=triton.next_power_of_2(32),
        num_warps=num_warps,
        num_stages=1,
    )
    return x_q, scales

@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize(A, B, C):
    out_scales = A * B
    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)


@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA(A, B, C):
    out_scales = A.T * B
    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)


@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transA_float(A, B, C):
    out_scales = A.T * B
    return out_scales * C.to(dtype=torch.float32)


@torch.compile(mode="max-autotune-no-cudagraphs")
def channelwise_dequantize_transB(A, B, C):
    out_scales = A * B.T
    return (out_scales * C.to(dtype=torch.float32)).to(torch.bfloat16)
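
# Usage as in the channelwise benchmark: A and B are the per-token scales returned by
# the quantizers above ((M, 1) from the rowwise variants, (1, N) from the column-wise
# _v2/_opt variants) and C is the int32 GEMM output:
#   y  = channelwise_dequantize_transB(x_scales, w_scales, y_int32)    # TN forward
#   dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)         # NN dgrad
#   dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)  # NT wgrad
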
def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
......