Commit 9406ff31 authored by yuguo

[DCU] support NVTE_USE_HIPBLASLT_GROUPEDGEMM

parent bc2d9697
@@ -903,6 +903,66 @@ void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHan
#endif
#ifdef __HIP_PLATFORM_AMD__
void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
const int num_gemms, bool transa, bool transb, bool grad,
NVTETensor *workspace, bool accumulate,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_grouped_gemm);
using namespace transformer_engine;
std::vector<const Tensor*> inputA;
std::vector<const Tensor*> inputB;
std::vector<Tensor*> outputD;
std::vector<const Tensor*> biasTensor;
std::vector<Tensor*> outputGelu;
std::vector<int64_t> m;
std::vector<int64_t> n;
std::vector<int64_t> k;
std::vector<int64_t> b;
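// Gather per-GEMM operands and derive each problem's (m, n, k) from the
// tensors' flattened 2D shapes; the batch count b is fixed at 1 per GEMM.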
for (int i = 0; i < num_gemms; i++) {
inputA.push_back(convertNVTETensorCheck(A[i]));
inputB.push_back(convertNVTETensorCheck(B[i]));
outputD.push_back(convertNVTETensorCheck(D[i]));
biasTensor.push_back(convertNVTETensorCheck(bias[i]));
outputGelu.push_back(convertNVTETensorCheck(pre_gelu_out[i]));
b.push_back(1);
size_t A0 = inputA[i]->flat_first_dim();
size_t A1 = inputA[i]->flat_last_dim();
size_t B0 = inputB[i]->flat_first_dim();
size_t B1 = inputB[i]->flat_last_dim();
if (transa) {
m.push_back(A0);
k.push_back(A1);
} else {
m.push_back(A1);
k.push_back(A0);
}
if (transb) {
n.push_back(B1);
} else {
n.push_back(B0);
}
}
Tensor *wspace = convertNVTETensorCheck(workspace[0]);
if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) {
NVTE_ERROR("MOE nvte_grouped_gemm not surpport bias or gelu.");
}
hipblaslt_groupedgemm(inputA, inputB, outputD, m, n, k, b,
(transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
(transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
wspace->data.dptr, wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count, stream);
}
void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out,
@@ -22,6 +22,7 @@
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#include <hipcub/hipcub.hpp>
#include <hipblaslt/hipblaslt-ext.hpp>
#endif
#include <iostream>
#include <cstdlib>
@@ -50,6 +51,10 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
return HIP_R_8F_E4M3;
case DType::kFloat8E5M2:
return HIP_R_8F_E5M2;
case DType::kInt8:
return HIP_R_8I;
case DType::kInt32:
return HIP_R_32I;
default:
NVTE_ERROR("Invalid type");
}
@@ -1367,6 +1372,147 @@ void hipblaslt_gemm(const Tensor *inputA,
NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
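// Caches one pinned-host hipblaslt_ext::UserArguments buffer per device, for
// the (currently disabled) device-side argument path further below.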
class userArgsManager {
public:
userArgsManager() {}
~userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : userArgs_map_) {
hipHostFree(device_pair.second); // Allocated with hipHostMalloc, so release with hipHostFree. Only one userArgs per device.
}
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = userArgs_map_.find(device_id);
if (device_it != userArgs_map_.end()) {
return device_it->second;
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext::UserArguments* userArgs;
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
// Store the userArgs in the map for this device
userArgs_map_[device_id] = userArgs;
return userArgs;
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*> userArgs_map_; // Map from device_id to its UserArguments buffer
std::mutex mutex_;
};
// A static userArgs manager would be defined here; the device-side
// UserArguments path is currently disabled (see the commented-out code below).
// static userArgsManager UAManager;
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                           std::vector<Tensor*>& outputD,
                           std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k,
                           std::vector<int64_t>& b, hipblasOperation_t transa, hipblasOperation_t transb,
                           void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                           int math_sm_count, hipStream_t stream, int compute_stream_offset = 0) {
// Check that compute_stream_offset is in the valid range.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
// int device_id;
// hipGetDevice(&device_id);
// hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
hipblasLtHandle_t handle = nullptr;
if (compute_stream_offset != -1) {
// Init hipblaslt handles (once, globally). Note: only a single handle is
// allocated here, so this indexing assumes compute_stream_offset == 0.
static std::once_flag init_flag;
static hipblasLtHandle_t hipblaslt_handles[1];
std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
handle = hipblaslt_handles[compute_stream_offset];
}
const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
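// FP32 accumulation by default; switched to INT32 below when A and B are int8.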
float one = 1.0f;
float zero = 0.0f;
float beta = accumulate ? one : zero;
int int_one = 1;
int int_zero = 0;
int int_beta = int_zero;
bool use_int8 = false;
if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) {
NVTE_CHECK(!accumulate, "Int8 GEMM does not support accumulate.");
use_int8 = true;
computeType = HIPBLAS_COMPUTE_32I;
}
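// Cap the workspace the heuristic may select to the caller-provided buffer.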
hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(workspaceSize);
hipblaslt_ext::GroupedGemm groupedgemm(handle,
transa,
transb,
A_type,
B_type,
D_type,
D_type,
computeType);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
    hipblaslt_ext::GemmEpilogue()}; // No action needed; the default is HIPBLASLT_EPILOGUE_DEFAULT (GEMM only).
std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
for (size_t i = 0; i < m.size(); i++) {
inputs[i].a = inputA[i]->data.dptr;
inputs[i].b = inputB[i]->data.dptr;
inputs[i].c = outputD[i]->data.dptr;
inputs[i].d = outputD[i]->data.dptr;
inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
}
// hipblaslt_ext::GemmEpilogue supports broadcasting, so the single-element
// epilogue vector applies to all GEMMs in the group.
groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
const int request_solutions = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
NVTE_CHECK_HIPBLASLT(
groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
if (heuristicResult.empty()) {
NVTE_ERROR("hipBLASLt grouped GEMM: no valid solution found!");
}
// Get the default values from the groupedgemm object
// groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
// NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs,
// userArgs,
// m.size() * sizeof(hipblaslt_ext::UserArguments),
// hipMemcpyHostToDevice, stream));
// Make sure to initialize every time the algo changes
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
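// Host-side argument path: initialize with the chosen algo/workspace and run
// directly on the stream, with no device-side UserArguments.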
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
}
#endif //USE_HIPBLASLT
#ifdef USE_ROCBLAS // Use rocblas + kernel, no fusion
@@ -564,13 +564,24 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
wrappers.emplace_back(std::move(wsp));
}
// Select the backend: the hipBLASLt grouped-GEMM path when
// NVTE_USE_HIPBLASLT_GROUPEDGEMM=1, otherwise the multi-stream cublas path.
const char *NVTE_USE_HIPBLASLT_GROUPEDGEMM = std::getenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM");
if (NVTE_USE_HIPBLASLT_GROUPEDGEMM != nullptr && NVTE_USE_HIPBLASLT_GROUPEDGEMM[0] == '1') {
NVTE_SCOPED_GIL_RELEASE({
nvte_grouped_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
});
} else {
NVTE_SCOPED_GIL_RELEASE({
nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
te_bias_vector.data(), te_pre_gelu_out_vector.data(),
te_A_vector.size(), transa, transb, grad,
te_workspace_vector.data(), accumulate, use_split_accumulator,
math_sm_count, at::cuda::getCurrentCUDAStream());
});
}
return bias;
}
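Usage note (an illustration, not part of the commit): the hipBLASLt grouped-GEMM
backend is opt-in, and the gate is re-read from the environment on every call to
te_general_grouped_gemm, so it can be toggled by the launching process. A
minimal sketch, assuming a POSIX host where Transformer Engine is in use:

#include <cstdlib>  // setenv

int main() {
  // "1" selects nvte_grouped_gemm; any other value (or unset) falls back to
  // the multi-stream cublas path in te_general_grouped_gemm.
  setenv("NVTE_USE_HIPBLASLT_GROUPEDGEMM", "1", /*overwrite=*/1);
  // ... run grouped GEMMs through Transformer Engine as usual ...
  return 0;
}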