Commit 403db136 authored by yuguo
parents c3a36d7e 76023d21
@@ -504,3 +504,57 @@ for i in range(20):
# dw = channelwise_dequantize_transB(dy_scales, x_scales, dw_int32)
torch.cuda.synchronize()
end = time.time()
# batch GEMM wgrad
m = 32
k = 32
n = 32
b = 4
transa = False
transb = True
dy_int8 = (torch.randn((b, m, n), device=device)).to(dtype=torch.int8)
x_int8 = (torch.randn((b, m, k), device=device)).to(dtype=torch.int8)
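# note: float -> int8 casting truncates toward zero, so these entries land in
# a small range (mostly -1, 0, 1) - enough to exercise the int8 kernels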
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
int32_dw_list = []
for i in range(b):
    int32_dw = torch._int_mm(dy_int8[i].t(), x_int8[i])
    # bf16_dw = torch.matmul(dy_int8[i].t(), x_int8[i])
    int32_dw_list.append(int32_dw)
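# reference wgrad: stack the per-batch dy[i]^T @ x[i] products into (b, n, k)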
batched_int32_dw = torch.stack(int32_dw_list)
# print("batched_int32_dw.shape: ", batched_int32_dw.shape)
# print("batched_int32_dw: ", batched_int32_dw)
out_dtype = torch.int32
out = torch.empty((b, n, k), dtype=out_dtype, device=device)
te_dw = tex.generic_batchgemm(
    x_int8.view(-1, x_int8.size(-1)),
    transa,
    dy_int8.view(-1, dy_int8.size(-1)),
    transb,
    out.view(-1, out.size(-1)),
    b,
    out_quantizer,
    TE_DType[out_dtype],
    bias,
    bias_dtype,
    use_gelu,
    aux_tensor,
    use_grad,
    workspace,
    workspace.shape[0],
    accumulate,
    use_split_accumulator,
)[0]
# print("te_dw.shape: ", te_dw.view(b, -1, te_dw.size(-1)).shape)
# print("te_dw: ", te_dw.view(b, -1, te_dw.size(-1)))
torch.testing.assert_close(te_dw.view(b, -1, te_dw.size(-1)), batched_int32_dw, atol=1e-5, rtol=1e-5)
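For reference, a minimal dequantization sketch for the int32 wgrad above,
assuming per-channel scales (dy_scales per output row, x_scales per output
column, in the spirit of the commented-out channelwise_dequantize_transB);
dequant_wgrad is a hypothetical helper, not part of this diff:

def dequant_wgrad(dw_int32, dy_scales, x_scales):
    # dw[j, l] = sum_m dy[m, j] * x[m, l]; with dy_q = dy / s_dy[j] and
    # x_q = x / s_x[l], the scales factor out as an outer product
    return dw_int32.float() * dy_scales.view(-1, 1) * x_scales.view(1, -1)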
@@ -970,4 +970,74 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
batch_count,
stream);
}
// Added for batched GEMM: v2 entry point (no bias/GELU support)
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_batchgemm_v2);
using namespace transformer_engine;
const Tensor *inputA = convertNVTETensorCheck(A);
const Tensor *inputB = convertNVTETensorCheck(B);
Tensor *outputD = convertNVTETensor(D);
const Tensor *biasTensor = convertNVTETensor(bias);
Tensor *outputGelu = convertNVTETensor(pre_gelu_out);
Tensor *wspace = convertNVTETensor(workspace);
if ((biasTensor->data.dptr != nullptr) || (outputGelu->data.dptr != nullptr)) {
NVTE_ERROR("MOE batchgemm not surpport bias or gelu.");
}
// Per-batch GEMM dimensions: A and B arrive flattened to 2-D with
// batch_count matrices stacked along dim 0, so the stacked dimension is
// divided out. The same formulas cover the NN, NT and TN layouts; TT is
// rejected below.
const int m = transa ? inputA->data.shape[0] / batch_count : inputA->data.shape[1];
const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0] / batch_count;
const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0] / batch_count;
int lda, ldb, ldd;
if (transa && !transb) { // TN
lda = k;
ldb = k;
ldd = m;
} else if (!transa && !transb) { // NN
lda = m;
ldb = k;
ldd = m;
} else if (!transa && transb) { // NT
lda = m;
ldb = n;
ldd = m;
} else { // TT
NVTE_ERROR("TT layout not allowed.");
}
hipblas_batchgemm(inputA,
inputB,
outputD,
biasTensor,
outputGelu,
m, n, k,
lda, ldb, ldd,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
0,
0,
false,
nullptr,
batch_count,
stream);
}
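For intuition on the leading-dimension choices above, a small PyTorch
illustration (not library code) of the NT wgrad case: a row-major
(rows, cols) buffer is exactly the transpose of what a column-major GEMM
sees, so with opA=N, opB=T, lda=m, ldb=n, ldd=m the column-major product
transposes back to dy^T @ x per batch.

import torch
mm, kk, nn = 8, 4, 2                           # hypothetical per-batch sizes
x = torch.randn(mm, kk)                        # A chunk, row-major
dy = torch.randn(mm, nn)                       # B chunk, row-major
out_cm = x.t() @ dy                            # (kk, nn): the column-major product
assert torch.allclose(out_cm.t(), dy.t() @ x)  # row-major result is the wgrad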
#endif
\ No newline at end of file
@@ -182,6 +182,16 @@ void hipblas_batchgemm(const Tensor *inputA,
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
int int_one = 1;
int int_zero = 0;
int int_beta = int_zero;
bool use_int8 = false;
if ((A_type == HIPBLAS_R_8I) && (B_type == HIPBLAS_R_8I) && (D_type == HIPBLAS_R_32I)) {
  NVTE_CHECK(!accumulate, "Int8 GEMM does not support accumulate.");
  use_int8 = true;
  // GemmEx-style APIs expect alpha/beta to match the compute type, so the
  // int32 compute path below must pass int pointers instead of float ones.
  computeType = HIPBLAS_R_32I;
}
hipblasSetStream(handle, stream);
// execute multiply
@@ -197,7 +207,7 @@ void hipblas_batchgemm(const Tensor *inputA,
m,
n,
k,
use_int8 ? static_cast<const void*>(&int_one) : static_cast<const void*>(&one),
A,
A_type,
lda,
@@ -206,7 +216,7 @@ void hipblas_batchgemm(const Tensor *inputA,
B_type,
ldb,
strideB,
use_int8 ? static_cast<const void*>(&int_beta) : static_cast<const void*>(&beta),
D,
D_type,
ldd,
...
@@ -122,6 +122,11 @@ void nvte_cublas_batchgemm(const NVTETensor A, const NVTETensor B, NVTETensor D,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
void nvte_cublas_batchgemm_v2(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias,
NVTETensor pre_gelu_out, bool transa, bool transb, bool grad,
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, int batch_count, cudaStream_t stream);
#endif
#ifdef __cplusplus
...
@@ -88,6 +88,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
std::optional<CommOverlapType> comm_type = std::nullopt,
MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
...
@@ -271,6 +271,138 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
return out;
}
std::vector<py::object> generic_batchgemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
at::Tensor workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, CommOverlapCore* comm_overlap,
std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
bool bulk_overlap) {
// Input tensors
NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");
auto none = py::none();
TensorWrapper A_tensor = makeTransformerEngineTensor(A, none);
TensorWrapper B_tensor = makeTransformerEngineTensor(B, none);
const bool low_precision =
detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype());
// Check tensor dimensions
const auto& A_shape = A_tensor.shape();
const auto& B_shape = B_tensor.shape();
const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb);
NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension");
NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension");
// Output tensor
TensorWrapper D_tensor;
if (D.is_none()) {
NVTE_ERROR("generic batchgemm D must be not None.");
} else {
D_tensor = makeTransformerEngineTensor(D, quantizer);
if (out_dtype) {
NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ",
static_cast<int>(*out_dtype), ", found ", static_cast<int>(D_tensor.dtype()), ")");
}
}
// Bias tensor
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) {
if (grad) {
auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
bias_tensor = makeTransformerEngineTensor(*bias_grad);
} else {
if (!bias->is_contiguous()) {
bias = bias->contiguous();
}
bias_tensor = makeTransformerEngineTensor(*bias);
}
}
// Activation input tensor
MaybeTensor pre_gelu_out = std::nullopt;
DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
if (gelu) {
if (!grad) {
auto dtype = GetATenDType(gelu_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto v : D_shape) {
torch_shape.push_back(v);
}
pre_gelu_out = at::empty(torch_shape, opts);
} else {
if (gelu_in.has_value()) {
pre_gelu_out = *gelu_in;
}
}
}
const auto gelu_shape = gelu ? D_shape : std::vector<size_t>{0};
auto te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);
// Workspace
auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
std::vector<size_t>{workspaceSize}, DType::kByte);
// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
const int device_id = at::cuda::current_device();
const int sm_count = transformer_engine::cuda::sm_count(device_id);
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);
// Keep the swizzled scaling factor tensors alive during the GEMM.
std::vector<std::optional<at::Tensor>> swizzled_scale_inverses_list;
auto main_stream = at::cuda::getCurrentCUDAStream();
if (A_tensor.numel() != 0 && B_tensor.numel() != 0) {
// Optionally swizzle the scaling factors
swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa)));
swizzled_scale_inverses_list.emplace_back(
std::move(swizzle_scaling_factors(B_tensor, !transb)));
if (comm_overlap) {
NVTE_ERROR("generic batchgemm not surpport comm_overlap.");
} else {
// Launch GEMM
NVTE_SCOPED_GIL_RELEASE({
nvte_cublas_batchgemm_v2(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
accumulate, use_split_accumulator, num_math_sms, batch_count, main_stream);
});
}
} else {
if (D_tensor.numel() != 0 && !accumulate) {
D_tensor.zero_(main_stream);
}
if (bias.has_value()) {
if (bias->numel() != 0 && grad) {
bias_grad->zero_();
}
}
}
// Pack outputs
std::vector<py::object> out;
out.emplace_back(std::move(D));
out.emplace_back(py::cast(bias_grad));
if (gelu && !grad) {
out.emplace_back(py::cast(*pre_gelu_out));
} else {
out.emplace_back(py::none());
}
if (extra_output.has_value()) {
out.emplace_back(py::cast(extra_output));
} else {
out.emplace_back(py::none());
}
return out;
}
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
...
@@ -110,6 +110,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("generic_batchgemm", transformer_engine::pytorch::generic_batchgemm, "Compute Batched GEMM (matrix-matrix multiply)",
py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), py::arg("batchcount"),
py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
py::arg("quantizer")); py::arg("quantizer"));
m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
......
...@@ -153,6 +153,8 @@ def _per_token_quant_fp8_to_int8( ...@@ -153,6 +153,8 @@ def _per_token_quant_fp8_to_int8(
def per_token_quant_fp8_to_int8(x, fp8_scale_inv, inplace=False): def per_token_quant_fp8_to_int8(x, fp8_scale_inv, inplace=False):
    assert x.is_contiguous()
    x = x.view(-1, x.shape[-1])
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    if inplace:
...
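The added assert guards the view call: a non-contiguous tensor (e.g. a
transpose) cannot be flattened with view, so callers must pass contiguous
input. A small repro of what the assert rejects:

t = torch.randn(4, 8).t()                     # transposed -> non-contiguous
# t.view(-1, t.shape[-1])                     # would raise a RuntimeError
t = t.contiguous().view(-1, t.shape[-1])      # the contiguous path works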