Commit 7a923605 authored by yuguo

[DCU] tensorwise int8 train opt

parent 686e93cd
......@@ -32,11 +32,56 @@ import os
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))
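# NVTE_INT8_SIM_FP8_TENSORWISE selects the tensorwise INT8 path (a single per-tensor scale_inv
# per operand) instead of the per-token/channelwise INT8 path used elsewhere in this script.
# NVTE_INT8_SIM_FP8_TENSORWISE_CHECK additionally compares each INT8 result against a
# hipBLASLt FP8 reference GEMM via assert_allclose below.
# Assumed form of the tensorwise dequantization performed by tensorwise_dequantize below:
#   output ~= (x_scale_inv * w_scale_inv) * y_int32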
def dtype_tols(dtype: torch.dtype) -> Dict[str, float]:
"""Estimated numerical error for a datatype
Based on tolerances for torch.testing.assert_close.
"""
if dtype == torch.float32:
return dict(rtol=1.3e-6, atol=1e-5)
if dtype == torch.float16:
return dict(rtol=1e-3, atol=1e-5)
if dtype == torch.bfloat16:
return dict(rtol=1.6e-2, atol=1e-5)
raise ValueError(f"Unsuppored dtype ({dtype})")
def assert_allclose(
l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float = None, rtol: float = None
) -> bool:
"""Ensures two lists are equal."""
assert len(l1) == len(l2), "Unequal number of outputs."
for i, (t1, t2) in enumerate(zip(l1, l2)):
tols = dtype_tols(t2.dtype)
if rtol is not None:
tols["rtol"] = rtol
if atol is not None:
tols["atol"] = atol
result = torch.allclose(t1, t2, **tols)
if not result:
diff = torch.abs(t1 - t2)
tol = tols["atol"] + (tols["rtol"] * torch.abs(t2))
exceed_mask = diff > tol
if exceed_mask.any():
indices = torch.nonzero(exceed_mask, as_tuple=True)
max_diff = diff[exceed_mask].max()
max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
max_location = [idx[max_idx].item() for idx in indices]
msg = (
f"Outputs not close enough in tensor at idx={i}. "
f"Maximum difference at location {max_location} "
f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
f"(diff {max_diff.item()})."
)
raise AssertionError(msg)
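# Usage (as in the checks below): assert_allclose([int8_path_output], [fp8_reference_output])
# raises with the worst-offending element and its location when tolerances are exceeded.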
# TN
m = 4096
k = 4096
n = 4096
seed = 0
n = 6144
seed = 4096
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
device = "cuda"
......@@ -235,69 +280,94 @@ transb = False
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
output = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
if int8_simulation_fp8_tensorwise:
x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
bf16_out = torch.matmul(x_bf16, w_bf16.t())
print("bf16_out: ", bf16_out)
torch.cuda.synchronize()
start = time.time()
for i in range(20):
bf16_out = torch.matmul(x_bf16, w_bf16.t())
torch.cuda.synchronize()
end = time.time()
# x_int8, x_scales = per_token_quant_int8(x_bf16)
# w_int8, w_scales = per_token_quant_int8(w_bf16)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# Cast to FP8 and back
x_fp8 = to_float8_CS(x_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
w_fp8 = to_float8_CS(w_bf16, fp8_dtype=tex.DType.kFloat8E4M3)
# print("x_fp8: ", x_fp8._data.view(dtype=torch.float8_e4m3fn))
# print("w_fp8: ", w_fp8._data.view(dtype=torch.float8_e4m3fn))
if int8_simulation_fp8_tensorwise:
x_int8, x_scales = x_fp8._data.view(dtype=torch.int8), x_fp8._scale_inv
w_int8, w_scales = w_fp8._data.view(dtype=torch.int8), w_fp8._scale_inv
else:
x_int8, x_scales = per_token_quant_fp8_to_int8(x_fp8._data.view(dtype=torch.float8_e4m3fn), x_fp8._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(w_fp8._data.view(dtype=torch.float8_e4m3fn), w_fp8._scale_inv, False)
# print("x_int8: ", x_int8)
# print("w_int8: ", w_int8)
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
y_int32 = tex.generic_gemm(
w_int8,
transa,
x_int8,
transb,
out,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# y_int32 = torch._int_mm(x_int8, w_int8.t())
# print("y_int32: ", y_int32)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
print("output: ", output)
if tensorwise_int8_check:
lt_output = tex.generic_gemm(
w_fp8,
transa,
x_fp8,
transb,
out,
out_quantizer,
TE_DType[torch.bfloat16],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
True,
)[0]
print("lt_output: ", lt_output)
assert_allclose([output], [lt_output])
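# lt_output is the hipBLASLt FP8 GEMM result used as the reference here; the INT8-simulated
# output above is expected to match it within the bf16 tolerances from dtype_tols().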
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(x_scales, w_scales, y_int32, output)
else:
output = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
# print("out_scales.shape: ", out_scales.shape)
# print("out_scales: ", out_scales)
print("bf16_out: ", bf16_out)
print("output: ", output)
# torch.cuda.synchronize()
# start = time.time()
......@@ -339,6 +409,7 @@ w_bf16 = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
dx = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
bf16_dx = torch.matmul(dy_bf16, w_bf16)
print("bf16_dx: ", bf16_dx)
torch.cuda.synchronize()
start = time.time()
......@@ -397,11 +468,32 @@ if int8_simulation_fp8_tensorwise:
else:
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
# dx = channelwise_dequantize_transB(dy_scales, w_scales, dx_int32)
print("dx: ", dx)
if tensorwise_int8_check:
lt_dx = tex.generic_gemm(
w_fp8,
transa,
dy_fp8,
transb,
out,
out_quantizer,
TE_DType[torch.bfloat16],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
True,
)[0]
print("lt_dx: ", lt_dx)
assert_allclose([dx], [lt_dx])
# print("dx_scales.shape: ", dx_scales.shape)
# print("dx_scales: ", dx_scales)
print("bf16_dx: ", bf16_dx)
print("dx: ", dx)
# torch.cuda.synchronize()
# start = time.time()
......@@ -447,11 +539,9 @@ transb = True
dy_bf16 = (torch.randn((m, n), device=device)).to(dtype=torch.bfloat16)
x_bf16 = (torch.randn((m, k), device=device)).to(dtype=torch.bfloat16)
dw = (torch.randn((n, k), device=device)).to(dtype=torch.bfloat16)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
bf16_dw = torch.matmul(dy_bf16.t(), x_bf16)
print("bf16_dw: ", bf16_dw)
torch.cuda.synchronize()
start = time.time()
......@@ -504,9 +594,30 @@ if int8_simulation_fp8_tensorwise:
else:
dw = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
print("bf16_dw: ", bf16_dw)
print("dw: ", dw)
if tensorwise_int8_check:
lt_dw = tex.generic_gemm(
x_fp8,
transa,
dy_fp8,
transb,
out,
out_quantizer,
TE_DType[torch.bfloat16],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
True,
)[0]
print("lt_dw: ", lt_dw)
assert_allclose([dw], [lt_dw])
# torch.cuda.synchronize()
# start = time.time()
# for i in range(20):
......@@ -548,9 +659,9 @@ print("dw: ", dw)
# batch gemm wgrad
m = 32
k = 32
n = 32
m = 1024
k = 1024
n = 1024
b = 4
transa = False
......@@ -558,9 +669,6 @@ transb = True
dy_int8 = (torch.randn((b, m, n), device=device)).to(dtype=torch.int8)
x_int8 = (torch.randn((b, m, k), device=device)).to(dtype=torch.int8)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
int32_dw_list = []
for i in range(b):
......@@ -597,3 +705,92 @@ te_dw = tex.generic_batchgemm(
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=0, rtol=0)
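# The batched INT8 wgrad result is compared exactly (atol=0, rtol=0) against the per-batch
# int32 references accumulated in int32_dw_list, since both paths accumulate in int32.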
# NT
b = 4
transa = False
transb = True
dy_bf16 = [(torch.randn((m, n), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
x_bf16 = [(torch.randn((m, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw_ref = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
dw = [(torch.randn((n, k), device=device)).to(dtype=torch.bfloat16) for i in range(b)]
# Cast to FP8 and back
dy_fp8 = [to_float8_CS(dy_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
x_fp8 = [to_float8_CS(x_bf16[i], fp8_dtype=tex.DType.kFloat8E5M2) for i in range(b)]
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = [dy_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [dy_fp8[i]._scale_inv for i in range(b)]
x_int8, x_scales = [x_fp8[i]._data.view(dtype=torch.int8) for i in range(b)], [x_fp8[i]._scale_inv for i in range(b)]
else:
dy_int8, dy_scales = [], []
x_int8, x_scales = [], []
assert False
for i in range(b):
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
# We are just capturing out.
dw_int32 = tex.generic_gemm(
x_int8[i],
transa,
dy_int8[i],
transb,
None,
None,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize(dy_scales[i], x_scales[i], dw_int32, dw_ref[i])
else:
assert False
dw_ref_tensor = torch.stack(dw_ref).contiguous()
# print("dw_ref_tensor: ", dw_ref_tensor)
torch.cuda.synchronize()
dy_int8_tensor = torch.stack(dy_int8).contiguous()
dy_scales_tensor = torch.stack(dy_scales).contiguous()
x_int8_tensor = torch.stack(x_int8).contiguous()
x_scales_tensor = torch.stack(x_scales).contiguous()
dw_tensor = torch.stack(dw).contiguous()
out_dtype = torch.bfloat16
dw_tensor = tex.tensorwise_int8_batchgemm(
x_int8_tensor.view(-1, x_int8_tensor.size(-1)),
transa,
dy_int8_tensor.view(-1, dy_int8_tensor.size(-1)),
transb,
x_scales_tensor,
dy_scales_tensor,
dw_tensor.view(-1, dw_tensor.size(-1)),
b,
out_quantizer,
TE_DType[out_dtype],
bias,
bias_dtype,
use_gelu,
aux_tensor,
use_grad,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
)[0]
# print("dw_tensor: ", dw_tensor)
torch.testing.assert_close(dw_ref_tensor, dw_tensor, atol=1e-5, rtol=1e-5)
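# dw_ref stacks b independent tensorwise INT8 GEMMs dequantized one at a time; the fused
# tex.tensorwise_int8_batchgemm result must match that stack to within 1e-5.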
......@@ -684,8 +684,9 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
const char *NVTE_FORCE_ROCM_GEMM = std::getenv("NVTE_FORCE_ROCM_GEMM");
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8 && use_split_accumulator) nvte_use_hipblaslt = 1;
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr!=nullptr) || (use_fp8) || (NVTE_FORCE_ROCM_GEMM != nullptr && NVTE_FORCE_ROCM_GEMM[0] == '1') || (nvte_use_hipblaslt) || (nvte_use_rocblas)) {
NVTE_CHECK(!use_int8, "INT8 GEMM only supports plain INT8 GEMM without any epilogue.");
cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, m, n, k, lda, ldb, ldd, transa, transb, grad,
wspace->data.dptr, wspace->data.shape[0], accumulate, use_split_accumulator, math_sm_count, 0, 0,
false, nullptr, stream, nvte_use_hipblaslt, nvte_use_rocblas, compute_stream_offset);
......@@ -1100,4 +1101,78 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
batch_count,
stream);
}
// Tensorwise INT8 strided-batch GEMM: A and B are flattened (batch_count * rows, cols) INT8
// tensors, and A_scales/B_scales hold one per-tensor scale_inv per batch entry.
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_batchgemm_v3);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
const Tensor *inputA_scales = convertNVTETensorCheck(A_scales);
const Tensor *inputB_scales = convertNVTETensorCheck(B_scales);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
}
int m, n, k;
// A and B arrive as (batch_count * rows, cols) tensors, so the batched leading dimension is
// divided by batch_count. The expressions below are identical for the NT, TN, and NN layouts;
// TT is rejected further down.
m = transa ? inputA->data.shape[0]/batch_count : inputA->data.shape[1];
k = transa ? inputA->data.shape[1] : inputA->data.shape[0]/batch_count;
n = transb ? inputB->data.shape[1] : inputB->data.shape[0]/batch_count;
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
hipblasLtHandle_t handle = nullptr;
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[0];
hipblaslt_batchgemm_tensorwise_int8(inputA, inputB, inputA_scales, inputB_scales, outputD, biasTensor, outputGelu,
m, n, k, lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad,
wspace->data.dptr,
wspace->data.shape[0], accumulate, use_split_accumulator,
math_sm_count, 0, 0,
false, nullptr, batch_count, stream,
handle);
}
#endif
\ No newline at end of file
......@@ -971,14 +971,17 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
hipblasLtHandle_t handle) {
void* A = inputA->data.dptr;
void* A_scale_inverse = inputA->scale_inv.dptr;
float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
void* B = inputB->data.dptr;
void* B_scale_inverse = inputB->scale_inv.dptr;
float* B_scale_inverse_float = (float*)(inputB->scale_inv.dptr);
void* D = outputD->data.dptr;
void* bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void* pre_gelu_out = outputPreGelu->data.dptr;
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
......@@ -988,11 +991,19 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"INT8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"INT8 input to GEMM requires inverse of scale!");
bool tensorwise_int8 = false;
const char* NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1;
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8) {
if (use_fp8 || use_int8) {
NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
}
float one = 1.0;
......@@ -1014,6 +1025,17 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
hipblasLtMatmulFlags_t matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
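// For tensorwise INT8, the matmul flag is assumed to ask hipBLASLt to apply the per-tensor
// A/B scale pointers to the INT32 accumulator when producing BF16 (or FP32) output.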
if (tensorwise_int8) {
if (D_type == HIP_R_16BF) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
} else if (D_type == HIP_R_32F) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32;
} else {
NVTE_CHECK(false, "tensorwise_int8 only surpport D_type bf16 or fp32!");
}
}
int64_t ld_gelumat = (int64_t)ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
......@@ -1026,7 +1048,11 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
transb == HIPBLAS_OP_N ? n : k, ldb));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
if (tensorwise_int8) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F, matmul_flag));
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
......@@ -1055,6 +1081,19 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
}
}
if (tensorwise_int8) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
(void*)&A_scale_inverse_float,
sizeof(void*)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
(void*)&B_scale_inverse_float,
sizeof(void*)));
if (bias) {
NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
}
}
if (bias && gelu) {
if (grad) {
......@@ -1260,6 +1299,422 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
void hipblaslt_batchgemm_tensorwise_int8(const Tensor *inputA,
const Tensor *inputB,
const Tensor *inputA_scales,
const Tensor *inputB_scales,
Tensor *outputD,
const Tensor *inputBias,
Tensor *outputPreGelu,
int m, int n, int k,
int lda, int ldb, int ldd,
hipblasOperation_t transa,
hipblasOperation_t transb,
bool grad,
void* workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
int m_split,
int n_split,
bool gemm_producer,
const Tensor *inputCounter,
size_t batch_count,
hipStream_t stream,
hipblasLtHandle_t handle
) {
void *A = inputA->data.dptr;
void *A_scale_inverse = inputA_scales->data.dptr;
float *A_scale_inverse_float = (float*)(inputA_scales->data.dptr);
void *B = inputB->data.dptr;
void *B_scale_inverse = inputB_scales->data.dptr;
float *B_scale_inverse_float = (float*)(inputB_scales->data.dptr);
void *D = outputD->data.dptr;
void *bias_ptr = inputBias->data.dptr;
const bool bias = bias_ptr != nullptr;
void *pre_gelu_out = outputPreGelu->data.dptr;
const bool gelu = pre_gelu_out != nullptr;
const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) ||
is_fp8_dtype(inputB->data.dtype);
const bool use_int8 = is_int8_dtype(inputA->data.dtype) ||
is_int8_dtype(inputB->data.dtype);
const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"FP8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
"INT8 input to GEMM requires inverse of scale!");
NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
"INT8 input to GEMM requires inverse of scale!");
bool tensorwise_int8 = false;
const char *NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' && use_int8) tensorwise_int8 = 1;
// check consistency of arguments:
// if fp8 is desired, context cannot be null
// fp8 + gelu fusion + fp8 aux is unavailable right now.
if (use_fp8 || use_int8) {
NVTE_CHECK(!gelu, "fp8/int8 gemm + gelu fusion is unavailable right now!");
}
float one = 1.0;
float zero = 0.0;
float beta = (accumulate) ? one : zero;
int device_id;
NVTE_CHECK_CUDA(hipGetDevice(&device_id));
if (handle == nullptr) {
handle = cached_handles.get(device_id);
if (handle == nullptr)
{
handle = cached_handles.obtain(device_id);
}
}
hipblasLtMatmulDesc_t operationDesc = nullptr;
hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
hipblasLtMatmulPreference_t preference = nullptr;
hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
hipblasLtMatmulFlags_t matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
if (tensorwise_int8) {
if (D_type == HIP_R_16BF) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_BF16;
} else if (D_type == HIP_R_32F) {
matmul_flag = HIPBLASLT_MATMUL_FLAGS_INT8_SCALE_FP32;
} else {
NVTE_CHECK(false, "tensorwise_int8 only surpport D_type bf16 or fp32!");
}
}
int64_t ld_gelumat = (int64_t) ldd;
// default to tf32 except for e5m2 inputs where the config is not supported
hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;
// Create matrix descriptors. Not setting any extra attributes.
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type,
transa == HIPBLAS_OP_N ? m : k,
transa == HIPBLAS_OP_N ? k : m,
lda));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type,
transb == HIPBLAS_OP_N ? k : n,
transb == HIPBLAS_OP_N ? n : k,
ldb));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
if (tensorwise_int8) {
// batch_count arrives as size_t but the batch-count layout attribute expects int32_t, and the
// strided-batch offsets expect int64_t; use explicitly sized locals and check the calls.
const int32_t batch_count_i32 = static_cast<int32_t>(batch_count);
const int64_t strideA = static_cast<int64_t>(m) * k;
const int64_t strideB = static_cast<int64_t>(k) * n;
const int64_t strideD = static_cast<int64_t>(m) * n;
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count_i32, sizeof(batch_count_i32)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutSetAttribute(Adesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideA, sizeof(strideA)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count_i32, sizeof(batch_count_i32)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideB, sizeof(strideB)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count_i32, sizeof(batch_count_i32)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutSetAttribute(Ddesc, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &strideD, sizeof(strideD)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F, matmul_flag));
} else {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
&transa, sizeof(transa)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb)));
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
// amax(D) either (next op is high precision).
if (use_fp8) {
// Split accumulator.
const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
/*
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_FAST_ACCUM, //TODO: We don't have fast accum mode yet
&fastAccuMode,
sizeof(fastAccuMode)));
*/
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
&A_scale_inverse,
sizeof(A_scale_inverse)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
&B_scale_inverse,
sizeof(B_scale_inverse)));
if (bias) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE,
&bias_type, sizeof(bias_type)));
}
}
if (tensorwise_int8) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
(void*)&A_scale_inverse_float,
sizeof(void*)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
(void*)&B_scale_inverse_float,
sizeof(void*)));
if (bias) {
NVTE_CHECK(false, "tensorwise_int8 not surpport bias!");
}
}
if (bias && gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
} else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
} else if (bias) {
if (grad) {
// grad output is always input B
epilogue = HIPBLASLT_EPILOGUE_BGRADB;
} else {
epilogue = HIPBLASLT_EPILOGUE_BIAS;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_ptr, sizeof(bias_ptr)));
} else if (gelu) {
if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU;
} else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&ld_gelumat, sizeof(ld_gelumat)));
}
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue, sizeof(epilogue)));
GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
use_fp8 ? bias_type : (hipDataType)-1,
m, n, k, lda, ldb, ldd, transa, transb, epilogue );
GemmAlgoCache::Algo cached_algo;
if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value())
{
int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
int algoTuneCount = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
if (tuneLoopCount)
{
/* HIPBLASLT may return hundreds of algos for some configs
* Limit amount by default. User may override with env
*/
static const int defaultAlgoCount = 16;
algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
}
algoTuneCount += firstAlgo;
int algoTotalCount = cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceSetAttribute(
preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspaceSize, sizeof(workspaceSize)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
Ddesc, preference, algoTotalCount, algoArr.data(),
&algoTotalCount));
algoArr.resize(algoTotalCount);
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));
//If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
if (cached_algo.hasId())
{
int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
for (int i=0; i<algoTotalCount; i++)
{
const auto &algo = algoArr[idx];
if (algo.state == HIPBLAS_STATUS_SUCCESS)
{
if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo))
{
cached_algo.algo = algo.algo;
if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index)
{
cached_algo.ws_size_min = algo.workspaceSize;
cached_algo.index = idx;
algoCache.store(gemm_cfg, cached_algo);
}
break;
}
}
idx = (idx + 1) % algoTotalCount;
}
if (logTuning && !cached_algo.algo.has_value())
{
std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId << " in hipBLASLt results" << std::endl;
}
}
//No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
if (!cached_algo.algo.has_value())
{
int bestAlgo = -1;
algoTuneCount = std::min(algoTuneCount, algoTotalCount);
if (tuneLoopCount > 0)
{
if (logTuning)
std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id
<< " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
<< tuneLoopCount << " loops " << std::endl;
NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
hipStream_t profilingStream;
NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
using tuning_clock = std::chrono::steady_clock;
tuning_clock::now(); //the first call takes little longer so do it outside the loop
tuning_clock::duration bestTime = tuning_clock::duration::max();
for (int algo=firstAlgo; algo<algoTuneCount; algo++)
{
if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS)
{
continue;
}
// Warm-up call
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize,
profilingStream)); /* stream */
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
//Profiling loop
tuning_clock::time_point startTime = tuning_clock::now();
for (int loop=0; loop<tuneLoopCount; loop++)
{
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&algoArr[algo].algo, /* algo */
workspace, /* workspace */
workspaceSize,
profilingStream)); /* stream */
}
NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
tuning_clock::duration algoTime = tuning_clock::now() - startTime;
if (algoTime < bestTime)
{
bestAlgo = algo;
bestTime = algoTime;
}
}
NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
if (bestAlgo >= 0)
{
if (logTuning)
std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
<< std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() / tuneLoopCount
<< " ns" << std::endl;
}
}
else if (firstAlgo < algoTuneCount)
{
bestAlgo = firstAlgo;
}
if (bestAlgo < 0) {
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
throw std::runtime_error("Unable to find any suitable algorithms");
}
cached_algo.algo = algoArr[bestAlgo].algo;
cached_algo.index = bestAlgo;
cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
cached_algo.ws_size_max = workspaceSize;
if (logTuning)
std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId << std::endl;
algoCache.store(gemm_cfg, cached_algo);
}
}
// D = alpha * (A * B) + beta * C
NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
A, /* A */
Adesc,
B, /* B */
Bdesc,
static_cast<const void*>(&beta), /* beta */
D, /* C */
Ddesc,
D, /* D */
Ddesc,
&cached_algo.algo.value(), /* algo */
workspace, /* workspace */
workspaceSize,
stream)); /* stream */
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
class userArgsManager {
public:
userArgsManager() {}
......@@ -1357,7 +1812,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally)
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[1];
static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
......
......@@ -134,6 +134,11 @@ void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTETensor A_scales, const NVTETensor B_scales, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
#endif
#ifdef __cplusplus
......
......@@ -131,7 +131,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
void CheckInputTensor(const Tensor &t, const std::string &name) {
const DType type = t.dtype();
if (is_fp8_dtype(type)) {
if (is_fp8_dtype(type) || is_int8_dtype(type)) {
// FP8/INT8 input needs to have scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8/INT8 scaling factor input ", name,
......
......@@ -38,6 +38,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.fp8 import int8_simulation_fp8, int8_simulation_fp8_tensorwise
tensorwise_int8_check = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE_CHECK", "0")))
__all__ = [
"general_gemm",
"general_grouped_gemm",
......@@ -181,6 +182,47 @@ def general_gemm(
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
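# Tensorwise path: the FP8-backed tensors are handed to tex.generic_gemm as-is. With
# NVTE_INT8_SIM_FP8_TENSORWISE=1 and use_split_accumulator=True the C++ side routes the INT8
# GEMM to hipBLASLt, which applies the per-tensor scale_inv factors inside the GEMM (see
# nvte_cublas_gemm / hipblaslt_gemm in this change), so no separate dequantize is needed here.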
args = (
A,
transa, # transa
B,
transb, # transb
out,
quantization_params,
TE_DType[out_dtype] if out_dtype is not None else None,
bias,
bias_dtype,
gelu,
gelu_in,
grad, # grad
workspace,
workspace.shape[0],
accumulate,
True,
)
kwargs = {
"comm_overlap": ub,
"comm_type": ub_type,
"extra_output": extra_output,
"bulk_overlap": bulk_overlap,
}
out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs)
if debug_quantizer is not None:
out = debug_quantizer.process_gemm_output(out)
return out, bias_grad, gelu_input, extra_output
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)):
assert not gelu, "GELU not supported with int8 simulation"
......@@ -195,12 +237,8 @@ def general_gemm(
if layout == "TN":
assert out_dtype is torch.bfloat16
out_shape = B._data.shape[:-1] + (A._data.shape[0], )
if int8_simulation_fp8_tensorwise:
x_int8, x_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
else:
x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
y_int32 = tex.generic_gemm(
w_int8,
......@@ -220,22 +258,14 @@ def general_gemm(
False,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
y = torch.empty_like(y_int32, device=y_int32.device, dtype=torch.bfloat16)
tensorwise_dequantize(x_scales, w_scales, y_int32, y)
else:
y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
y = channelwise_dequantize_transB(x_scales, w_scales, y_int32)
return y.view(out_shape), None, None, None
elif layout == "NN":
assert out_dtype is torch.bfloat16
dx_shape = B._data.shape[:-1] + (A._data.shape[-1], )
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
w_int8, w_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
else:
dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
w_int8, w_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
dx_int32 = tex.generic_gemm(
w_int8,
......@@ -255,21 +285,13 @@ def general_gemm(
False,
use_split_accumulator,
)[0]
if int8_simulation_fp8_tensorwise:
dx = torch.empty_like(dx_int32, device=dx_int32.device, dtype=torch.bfloat16)
tensorwise_dequantize(dy_scales, w_scales, dx_int32, dx)
else:
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
dx = channelwise_dequantize(dy_scales, w_scales, dx_int32)
return dx.view(dx_shape), None, None, None
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
if int8_simulation_fp8_tensorwise:
dy_int8, dy_scales = B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv
x_int8, x_scales = A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv
else:
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
dw_int32 = tex.generic_gemm(
x_int8,
......@@ -291,29 +313,15 @@ def general_gemm(
)[0]
if out_dtype is torch.bfloat16:
if accumulate:
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize_add(dy_scales, x_scales, dw_int32, out)
else:
channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
channelwise_dequantize_transA_add(dy_scales, x_scales, dw_int32, out)
else:
if int8_simulation_fp8_tensorwise:
out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.bfloat16)
tensorwise_dequantize(dy_scales, x_scales, dw_int32, out)
else:
out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
out = channelwise_dequantize_transA(dy_scales, x_scales, dw_int32)
else:
if accumulate:
if int8_simulation_fp8_tensorwise:
tensorwise_dequantize_float_add(dy_scales, x_scales, dw_int32, out)
else:
channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
channelwise_dequantize_transA_float_add(dy_scales, x_scales, dw_int32, out)
else:
if int8_simulation_fp8_tensorwise:
out = torch.empty_like(dw_int32, device=dw_int32.device, dtype=torch.float32)
tensorwise_dequantize_float(dy_scales, x_scales, dw_int32, out)
else:
out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
out = channelwise_dequantize_transA_float(dy_scales, x_scales, dw_int32)
return out, None, None, None
else:
......@@ -464,6 +472,120 @@ def general_grouped_gemm(
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert len(set(m_splits)) == 1, "Int8 simulation grouped GEMM currently requires equal m_splits (token padding), the same as batch GEMM."
assert not gelu, "GELU not supported with int8 simulation groupgemm."
assert not use_bias, "Bias not supported with int8 simulation groupgemm."
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32, "Out_dtype must be bfloat16 or float32 for int8 simulation"
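# With equal m_splits, the grouped GEMM is lowered to a single strided-batch INT8 GEMM: the
# per-expert FP8 payloads and per-tensor scales are stacked and dispatched through
# tex.tensorwise_int8_batchgemm instead of looping over tex.generic_gemm.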
if layout == "TN":
assert out_dtype is torch.bfloat16
qx_data_list, scales_x_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
num_gemms = len(A)
qx_data = torch.stack(qx_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
scales_x = torch.stack(scales_x_list).contiguous()
scales_w = torch.stack(scales_w_list).contiguous()
out[0] = tex.tensorwise_int8_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qx_data.view(-1, qx_data.size(-1)),
transb,
scales_w,
scales_x,
out[0],
num_gemms,
None,
TE_DType[out_dtype],
bias,
bias_dtype,
gelu,
gelu_input[0],
grad, # grad
workspaces[0],
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)[0]
return out, bias, gelu_input
if layout == "NN":
assert out_dtype is torch.bfloat16
qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
w_data_list, scales_w_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
num_gemms = len(A)
qdout_data = torch.stack(qdout_data_list).contiguous()
w_data = torch.stack(w_data_list).contiguous()
scales_dout = torch.stack(scales_dout_list).contiguous()
scales_w = torch.stack(scales_w_list).contiguous()
out[0] = tex.tensorwise_int8_batchgemm(
w_data.view(-1, w_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
scales_w,
scales_dout,
out[0],
num_gemms,
None,
TE_DType[out_dtype],
bias,
bias_dtype,
gelu,
gelu_input[0],
grad, # grad
workspaces[0],
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)[0]
return out, bias, gelu_input
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
qdout_data_list, scales_dout_list = [b._data.view(dtype=TE_DType_To_Torch[b._fp8_dtype]) for b in B], [b._scale_inv for b in B]
qx_data_list, scales_x_list = [a._data.view(dtype=TE_DType_To_Torch[a._fp8_dtype]) for a in A], [a._scale_inv for a in A]
num_gemms = len(A)
qdout_data = torch.stack(qdout_data_list).contiguous()
qx_data = torch.stack(qx_data_list).contiguous()
scales_dout = torch.stack(scales_dout_list).contiguous()
scales_x = torch.stack(scales_x_list).contiguous()
dw = torch.stack(out).contiguous()
dw = tex.tensorwise_int8_batchgemm(
qx_data.view(-1, qx_data.size(-1)),
transa,
qdout_data.view(-1, qdout_data.size(-1)),
transb,
scales_x,
scales_dout,
dw.view(-1, dw.size(-1)),
num_gemms,
None,
TE_DType[out_dtype],
bias,
bias_dtype,
gelu,
gelu_input[0],
grad, # grad
workspaces[0],
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
)
for i in range(num_gemms):
out[i].copy_(dw[i])
return out, bias, gelu_input
if int8_simulation_fp8 and (isinstance(A[0], Float8TensorBase) or isinstance(B[0], Float8TensorBase)):
assert len(set(m_splits)) == 1, "Int8 simulation grouped GEMM currently requires equal m_splits (token padding), the same as batch GEMM."
......
......@@ -88,14 +88,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
......@@ -113,6 +105,22 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
#ifdef __HIP_PLATFORM_AMD__
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
std::vector<py::object> tensorwise_int8_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::handle A_scales, py::handle B_scales, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
......
......@@ -271,138 +271,6 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
return out;
}
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore* comm_overlap,
std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
bool bulk_overlap) {
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
auto none = py::none();
TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
const auto& B_shape = B_tensor.shape();
const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
// Output tensor
TensorWrapper D_tensor;
if (D.is_none()) {
NVTE_ERROR("generic batchgemm D must be not None.");
} else {
D_tensor = makeTransformerEngineTensor(D, quantizer);
if (out_dtype) {
NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
}
}
// Bias tensor
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) {
if (grad) {
auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
bias_tensor = makeTransformerEngineTensor(*bias_grad);
} else {
if (!bias->is_contiguous()) {
bias = bias->contiguous();
}
bias_tensor = makeTransformerEngineTensor(*bias);
}
}
// Activation input tensor
MaybeTensor pre_gelu_out = std::nullopt;
DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
if (gelu) {
if (!grad) {
auto dtype = GetATenDType(gelu_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto v : D_shape) {
torch_shape.push_back(v);
}
pre_gelu_out = at::empty(torch_shape, opts);
} else {
if (gelu_in.has_value()) {
pre_gelu_out = *gelu_in;
}
}
}
const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
auto te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
// Workspace
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) {
NVTE_ERROR("generic batchgemm not surpport comm_overlap.");
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
});
}
} else {
if (D_tensor.numel() != 0 && !accumulate) {
D_tensor.zero_(main_stream);
}
if (bias.has_value()) {
if (bias->numel() != 0 && grad) {
bias_grad->zero_();
}
}
}
// Pack outputs
std::vector<py::object> out;
out.emplace_back(std::move(D));
out.emplace_back(py::cast(bias_grad));
if (gelu && !grad) {
out.emplace_back(py::cast(*pre_gelu_out));
} else {
out.emplace_back(py::none());
}
if (extra_output.has_value()) {
out.emplace_back(py::cast(extra_output));
} else {
out.emplace_back(py::none());
}
return out;
}
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
......@@ -586,6 +454,274 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
}
#ifdef USE_ROCM
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore* comm_overlap,
std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
bool bulk_overlap) {
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
auto none = py::none();
TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
const auto& B_shape = B_tensor.shape();
const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
// Output tensor
TensorWrapper D_tensor;
if (D.is_none()) {
NVTE_ERROR("generic batchgemm D must be not None.");
} else {
D_tensor = makeTransformerEngineTensor(D, quantizer);
if (out_dtype) {
NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
}
}
// Bias tensor
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) {
if (grad) {
auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
bias_tensor = makeTransformerEngineTensor(*bias_grad);
} else {
if (!bias->is_contiguous()) {
bias = bias->contiguous();
}
bias_tensor = makeTransformerEngineTensor(*bias);
}
}
// Activation input tensor
MaybeTensor pre_gelu_out = std::nullopt;
DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
if (gelu) {
if (!grad) {
auto dtype = GetATenDType(gelu_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto v : D_shape) {
torch_shape.push_back(v);
}
pre_gelu_out = at::empty(torch_shape, opts);
} else {
if (gelu_in.has_value()) {
pre_gelu_out = *gelu_in;
}
}
}
const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
auto te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
// Workspace
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) {
NVTE_ERROR("generic batchgemm not surpport comm_overlap.");
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
});
}
} else {
if (D_tensor.numel() != 0 && !accumulate) {
D_tensor.zero_(main_stream);
}
if (bias.has_value()) {
if (bias->numel() != 0 && grad) {
bias_grad->zero_();
}
}
}
// Pack outputs
std::vector<py::object> out;
out.emplace_back(std::move(D));
out.emplace_back(py::cast(bias_grad));
if (gelu && !grad) {
out.emplace_back(py::cast(*pre_gelu_out));
} else {
out.emplace_back(py::none());
}
if (extra_output.has_value()) {
out.emplace_back(py::cast(extra_output));
} else {
out.emplace_back(py::none());
}
return out;
}
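// Tensorwise INT8 batched GEMM. Mirrors generic_batchgemm, but A and B are
// expected to be INT8-quantized with per-tensor (tensorwise) scale factors
// supplied separately in A_scales/B_scales, and the GEMM is dispatched to
// nvte_cublas_batchgemm_v3 instead of _v2.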
std::vector<py::object> tensorwise_int8_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::handle A_scales,
    py::handle B_scales, py::object D, int batch_count, py::handle quantizer,
    std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type, bool gelu,
    MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize,
    bool accumulate, bool use_split_accumulator, CommOverlapCore* comm_overlap,
    std::optional<CommOverlapType> comm_type, MaybeTensor extra_output, bool bulk_overlap) {
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
NVTE_CHECK(!A_scales.is_none(), "Tensor A_scales has not been provided");
NVTE_CHECK(!B_scales.is_none(), "Tensor B_scales has not been provided");
auto none = py::none();
TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
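// The tensorwise scale factors are plain tensors, so wrap them without a quantizer.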
TensorWrapper A_scales_tensor = makeTransformerEngineTensor(A_scales, none);
TensorWrapper B_scales_tensor = makeTransformerEngineTensor(B_scales, none);
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
const auto& B_shape = B_tensor.shape();
const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
// Output tensor
TensorWrapper D_tensor;
if (D.is_none()) {
NVTE_ERROR("tensorwise int8 batchgemm D must be not None.");
} else {
D_tensor = makeTransformerEngineTensor(D, quantizer);
if (out_dtype) {
NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
}
}
// Bias tensor
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) {
if (grad) {
auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
bias_tensor = makeTransformerEngineTensor(*bias_grad);
} else {
if (!bias->is_contiguous()) {
bias = bias->contiguous();
}
bias_tensor = makeTransformerEngineTensor(*bias);
}
}
// Activation input tensor
MaybeTensor pre_gelu_out = std::nullopt;
DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
if (gelu) {
if (!grad) {
auto dtype = GetATenDType(gelu_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto v : D_shape) {
torch_shape.push_back(v);
}
pre_gelu_out = at::empty(torch_shape, opts);
} else {
if (gelu_in.has_value()) {
pre_gelu_out = *gelu_in;
}
}
}
const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
auto te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
// Workspace
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) {
NVTE_ERROR("generic batchgemm not surpport comm_overlap.");
} else {
// Launch GEMM
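// nvte_cublas_batchgemm_v3 extends the _v2 entry point with the A/B tensorwise
// scale tensors; presumably the scales are applied when the INT32 accumulator is
// converted to the output dtype (assumption based on the argument list).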
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_batchgemm_v3(A_tensor.data(), B_tensor.data(), A_scales_tensor.data(),
                         B_scales_tensor.data(), D_tensor.data(), bias_tensor.data(),
                         te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                         accumulate, use_split_accumulator, num_math_sms, batch_count,
                         main_stream);
});
}
} else {
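// Degenerate case: at least one operand is empty, so skip the GEMM and just
// zero the outputs that would otherwise hold stale data.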
if (D_tensor.numel() != 0 && !accumulate) {
D_tensor.zero_(main_stream);
}
if (bias.has_value()) {
if (bias->numel() != 0 && grad) {
bias_grad->zero_();
}
}
}
// Pack outputs
std::vector<py::object> out;
out.emplace_back(std::move(D));
out.emplace_back(py::cast(bias_grad));
if (gelu && !grad) {
out.emplace_back(py::cast(*pre_gelu_out));
} else {
out.emplace_back(py::none());
}
if (extra_output.has_value()) {
out.emplace_back(py::cast(extra_output));
} else {
out.emplace_back(py::none());
}
return out;
}
void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
......
......@@ -110,13 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("generic_batchgemm", transformer_engine::pytorch::generic_batchgemm, "Compute Batched GEMM (matrix-matrix multiply)",
py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("batchcount"),
py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
......@@ -213,6 +206,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
"Grouped GEMM");
#ifdef USE_ROCM
m.def("generic_batchgemm", transformer_engine::pytorch::generic_batchgemm, "Compute Batched GEMM (matrix-matrix multiply)",
py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("batchcount"),
py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("tensorwise_int8_batchgemm", transformer_engine::pytorch::tensorwise_int8_batchgemm, "Compute Tensorwise Int8 Batched GEMM (matrix-matrix multiply)",
py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("A_scales"), py::arg("B_scales"), py::arg("D"), py::arg("batchcount"),
py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("te_batchgemm_ts", &transformer_engine::pytorch::te_batchgemm_ts, "Batched GEMM"); /// rocblas
#endif
m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
......