Commit 0a90777e authored by wenjh

Fix blaslt group gemm dump

Signed-off-by: wenjh <wenjh@sugon.com>

Mutex group gemm
Signed-off-by: wenjh <wenjh@sugon.com>

do while group gemm
Signed-off-by: wenjh <wenjh@sugon.com>

Remove mutex
Signed-off-by: wenjh <wenjh@sugon.com>
parent 772a941a
@@ -887,10 +887,14 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
 } //namespace
 
-static inline void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
+static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
   NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
 }
+
+static void DestroyHipBlasLtHandle(hipblasLtHandle_t handle) {
+  NVTE_CHECK_HIPBLASLT(hipblasLtDestroy(handle));
+}
 
 using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle>;
 
 transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
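This hunk adds DestroyHipBlasLtHandle while the hipBlasLtHandleManager alias continues to pass only CreateHipBlasLtHandle. detail::HandleManager itself is outside this diff; below is a minimal stand-in sketch of the pattern it names (a lazily created handle with an optional destroy callback) — the template shape and all names here are assumptions, not the Transformer Engine code:

```cpp
#include <cstdio>

// Stand-in sketch, not the transformer_engine implementation: a lazily
// created handle with an optional destroy callback, one instance per thread.
template <typename Handle, void (*CreateFn)(Handle*), void (*DestroyFn)(Handle) = nullptr>
class HandleManager {
 public:
  static HandleManager& Instance() {
    static thread_local HandleManager instance;  // one manager (and handle) per thread
    return instance;
  }
  Handle GetHandle() {
    if (!created_) {
      CreateFn(&handle_);  // create on first use
      created_ = true;
    }
    return handle_;
  }
  ~HandleManager() {
    if (created_ && DestroyFn != nullptr) DestroyFn(handle_);  // release at thread exit
  }

 private:
  Handle handle_{};
  bool created_ = false;
};

// Toy callbacks standing in for hipblasLtCreate / hipblasLtDestroy.
static void CreateInt(int* h) { *h = 42; std::puts("create"); }
static void DestroyInt(int) { std::puts("destroy"); }

int main() {
  using Mgr = HandleManager<int, CreateInt, DestroyInt>;
  int h = Mgr::Instance().GetHandle();  // created here
  h = Mgr::Instance().GetHandle();      // cached: no second create
  std::printf("handle = %d\n", h);
}
```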
@@ -1240,36 +1244,202 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
   NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
 }
-struct HipBlasLtUserArgsDeleter {
-  void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
-    hipFree(ptr);
+struct HipBlasltHostUserArgs
+{
+  HipBlasltHostUserArgs(): raw_(nullptr), event_(nullptr) {}
+  HipBlasltHostUserArgs(size_t size): raw_(nullptr), event_(nullptr)
+  {
+    hipblaslt_ext::UserArguments* raw_ptr = nullptr;
+    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
+    raw_ = raw_ptr;
+    hipEvent_t event = nullptr;
+    NVTE_CHECK_CUDA(hipEventCreateWithFlags(&event, hipEventBlockingSync));
+    event_ = event;
+  }
+  HipBlasltHostUserArgs(const HipBlasltHostUserArgs&) = delete;
+  HipBlasltHostUserArgs(HipBlasltHostUserArgs&& other)
+  {
+    raw_ = other.raw_;
+    event_ = other.event_;
+    other.raw_ = nullptr;
+    other.event_ = nullptr;
+  }
+  HipBlasltHostUserArgs& operator=(const HipBlasltHostUserArgs&) = delete;
+  HipBlasltHostUserArgs& operator=(HipBlasltHostUserArgs&& other)
+  {
+    if(this != &other)
+    {
+      free();
+      raw_ = other.raw_;
+      event_ = other.event_;
+      other.raw_ = nullptr;
+      other.event_ = nullptr;
+    }
+    return *this;
+  }
+  inline hipblaslt_ext::UserArguments* getArgs() const noexcept
+  {
+    return raw_;
+  }
+  inline hipEvent_t getEvent() const noexcept
+  {
+    return event_;
+  }
+  ~HipBlasltHostUserArgs()
+  {
+    free();
+  }
+private:
+  inline void free()
+  {
+    if(raw_)
+    {
+      if(event_)
+      {
+        NVTE_CHECK_CUDA(hipEventSynchronize(event_));
+        NVTE_CHECK_CUDA(hipEventDestroy(event_));
+        event_ = nullptr;
+      }
+      NVTE_CHECK_CUDA(hipFree(raw_));
+      raw_ = nullptr;
+    }
+  }
+
+  hipblaslt_ext::UserArguments* raw_;
+  hipEvent_t event_;
+};
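HipBlasltHostUserArgs ties a pinned host allocation to a hipEvent_t so that its destructor cannot release memory the GPU may still be reading. Below is a self-contained sketch of the same idea with a plain float payload instead of hipblaslt_ext::UserArguments; PinnedStaging and CHECK_HIP are illustrative names, and hipHostFree is used to release the pinned pointer:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_HIP(cmd)                                                \
  do {                                                                \
    hipError_t e_ = (cmd);                                            \
    if (e_ != hipSuccess) {                                           \
      std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(e_)); \
      std::abort();                                                   \
    }                                                                 \
  } while (0)

// Illustrative sketch: a pinned staging buffer whose destructor waits on an
// event before releasing the memory.
struct PinnedStaging {
  explicit PinnedStaging(size_t n) {
    CHECK_HIP(hipHostMalloc(&data, n * sizeof(float)));               // pinned host memory
    CHECK_HIP(hipEventCreateWithFlags(&event, hipEventBlockingSync));
  }
  ~PinnedStaging() {
    CHECK_HIP(hipEventSynchronize(event));  // wait for the last copy that read this buffer
    CHECK_HIP(hipEventDestroy(event));
    CHECK_HIP(hipHostFree(data));           // now safe to unpin and free
  }
  float* data = nullptr;
  hipEvent_t event = nullptr;
};

int main() {
  const size_t n = 16;
  PinnedStaging staging(n);
  for (size_t i = 0; i < n; ++i) staging.data[i] = static_cast<float>(i);

  hipStream_t stream;
  float* dev = nullptr;
  CHECK_HIP(hipStreamCreate(&stream));
  CHECK_HIP(hipMalloc(&dev, n * sizeof(float)));
  // The async copy reads from the pinned buffer on the stream...
  CHECK_HIP(hipMemcpyAsync(dev, staging.data, n * sizeof(float), hipMemcpyHostToDevice, stream));
  // ...so record an event marking when the buffer may be reused or freed.
  CHECK_HIP(hipEventRecord(staging.event, stream));

  CHECK_HIP(hipStreamSynchronize(stream));
  CHECK_HIP(hipFree(dev));
  CHECK_HIP(hipStreamDestroy(stream));
}  // ~PinnedStaging: the recorded event has completed, so freeing is safe
```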
+struct HipBlasltHostUserArgsBuffer
+{
+  HipBlasltHostUserArgsBuffer() {}
+  HipBlasltHostUserArgsBuffer(size_t size)
+  {
+    for(int i = 0; i < 8; ++i)
+    {
+      buffer_[i] = std::move(HipBlasltHostUserArgs(size));
+    }
+  }
+  HipBlasltHostUserArgsBuffer(const HipBlasltHostUserArgsBuffer&) = delete;
+  HipBlasltHostUserArgsBuffer(HipBlasltHostUserArgsBuffer&& other) {
+    for(int i = 0; i < 8; ++i)
+    {
+      buffer_[i] = std::move(other.buffer_[i]);
+    }
+    index_ = other.index_;
+  }
+  HipBlasltHostUserArgsBuffer& operator=(const HipBlasltHostUserArgsBuffer&) = delete;
+  HipBlasltHostUserArgsBuffer& operator=(HipBlasltHostUserArgsBuffer&& other)
+  {
+    if(this != &other)
+    {
+      for(int i = 0; i < 8; ++i)
+      {
+        buffer_[i] = std::move(other.buffer_[i]);
+      }
+      index_ = other.index_;
+    }
+    return *this;
+  }
+  HipBlasltHostUserArgs& getHostUserArgs()
+  {
+    HipBlasltHostUserArgs& args = buffer_[index_];
+    if(index_ < 7)
+    {
+      ++index_;
+    }
+    else
+    {
+      index_ = 0;
+    }
+    return args;
   }
+private:
+  int index_ = 0;
+  HipBlasltHostUserArgs buffer_[8];
 };
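getHostUserArgs hands out the eight slots round-robin, so a slot is not reissued until seven further calls have happened; combined with the per-slot event this gives in-flight copies time to drain before a slot is rewritten. A generic sketch of the rotation (the slot count 8 matches the diff; everything else is illustrative):

```cpp
#include <array>
#include <cstdio>

// A slot handed out now is not handed out again until N-1 further calls.
template <typename Slot, int N = 8>
struct RingBuffer {
  Slot& next() {
    Slot& s = slots[index];
    index = (index + 1) % N;  // same effect as the index_ < 7 branch in the diff
    return s;
  }
  std::array<Slot, N> slots{};
  int index = 0;
};

int main() {
  RingBuffer<int> ring;
  for (int call = 0; call < 10; ++call) {
    int& slot = ring.next();
    std::printf("call %d -> slot %d\n", call,
                static_cast<int>(&slot - ring.slots.data()));
  }  // calls 8 and 9 wrap back to slots 0 and 1
}
```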
-using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
-inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
+using HipBlasLtHostUserArgsBufferPtr = std::unique_ptr<HipBlasltHostUserArgsBuffer>;
+
+HipBlasltHostUserArgsBuffer* getHipBlasLtHostUserArgsBuffer(size_t size)
+{
+  static thread_local std::unordered_map<size_t, HipBlasLtHostUserArgsBufferPtr> user_args_cache;
+  auto size_it = user_args_cache.find(size);
+  if (size_it != user_args_cache.end()) {
+    return size_it->second.get();
+  }
+  else
+  {
+    HipBlasLtHostUserArgsBufferPtr user_args(new HipBlasltHostUserArgsBuffer(size));
+    HipBlasltHostUserArgsBuffer* raw_ptr = user_args.get();
+    user_args_cache[size] = std::move(user_args);
+    return raw_ptr;
+  }
+}
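getHipBlasLtHostUserArgsBuffer memoizes one buffer set per distinct group size in a thread_local map, so repeated grouped GEMMs of the same shape reuse the pinned allocations instead of re-pinning host memory each call. The same pattern reduced to a stand-alone sketch (Buffer and getBuffer are hypothetical stand-in names):

```cpp
#include <cstddef>
#include <cstdio>
#include <memory>
#include <unordered_map>

struct Buffer {
  explicit Buffer(size_t n) : size(n) { std::printf("allocate for size %zu\n", n); }
  size_t size;
};

Buffer* getBuffer(size_t size) {
  // One cache per thread: no locking needed, entries live until thread exit.
  static thread_local std::unordered_map<size_t, std::unique_ptr<Buffer>> cache;
  auto it = cache.find(size);
  if (it != cache.end()) {
    return it->second.get();  // hit: reuse the existing allocation
  }
  auto inserted = cache.emplace(size, std::make_unique<Buffer>(size));
  return inserted.first->second.get();
}

int main() {
  getBuffer(4);   // allocates
  getBuffer(4);   // cache hit: no new allocation
  getBuffer(16);  // allocates a second entry
}
```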
+struct HipBlasLtDeviceUserArgs {
+  HipBlasLtDeviceUserArgs(): stream_(nullptr), raw_(nullptr) {}
+  HipBlasLtDeviceUserArgs(hipStream_t stream, size_t size): stream_(stream), raw_(nullptr)
+  {
   hipblaslt_ext::UserArguments* raw_ptr = nullptr;
-  if (host) {
-    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
-  } else {
   NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
+    raw_ = raw_ptr;
   }
-  return HipBlasLtUserArgsPtr(raw_ptr);
-}
+  HipBlasLtDeviceUserArgs(const HipBlasLtDeviceUserArgs&) = delete;
+  HipBlasLtDeviceUserArgs(HipBlasLtDeviceUserArgs&& other)
+  {
+    stream_ = other.stream_;
+    raw_ = other.raw_;
+    other.stream_ = nullptr;
+    other.raw_ = nullptr;
+  }
+  HipBlasLtDeviceUserArgs& operator=(const HipBlasLtDeviceUserArgs&) = delete;
+  HipBlasLtDeviceUserArgs& operator=(HipBlasLtDeviceUserArgs&& other)
+  {
+    if(this != &other)
+    {
+      free();
+      stream_ = other.stream_;
+      raw_ = other.raw_;
+      other.stream_ = nullptr;
+      other.raw_ = nullptr;
+    }
+    return *this;
+  }
+  inline hipblaslt_ext::UserArguments* get() const noexcept
+  {
+    return raw_;
+  }
+  ~HipBlasLtDeviceUserArgs()
+  {
+    free();
+  }
+protected:
+  inline void free()
+  {
+    if(raw_)
+    {
+      NVTE_CHECK_CUDA(hipFreeAsync(raw_, stream_));
+      raw_ = nullptr;
+    }
+  }
+
+  hipStream_t stream_;
+  hipblaslt_ext::UserArguments* raw_;
+};
-inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
-  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
-  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
-  std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
+using HipBlasLtDeviceUserArgsPtr = std::unique_ptr<HipBlasLtDeviceUserArgs>;
+
+HipBlasLtDeviceUserArgs* getHipBlasLtDeviceUserArgs(hipStream_t stream, size_t size)
+{
+  static thread_local std::unordered_map<size_t, HipBlasLtDeviceUserArgsPtr> user_args_cache;
   auto size_it = user_args_cache.find(size);
   if (size_it != user_args_cache.end()) {
     return size_it->second.get();
   }
   else
   {
-    HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
-    hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
+    HipBlasLtDeviceUserArgsPtr user_args(new HipBlasLtDeviceUserArgs(stream, size));
+    HipBlasLtDeviceUserArgs* raw_ptr = user_args.get();
     user_args_cache[size] = std::move(user_args);
     return raw_ptr;
   }
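HipBlasLtDeviceUserArgs frees its device allocation with hipFreeAsync on the stream captured at construction, so the free is ordered after any work already enqueued on that stream. A minimal runnable illustration of stream-ordered allocate/free (assuming a ROCm version that provides hipMallocAsync/hipFreeAsync):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_HIP(cmd)                                                \
  do {                                                                \
    hipError_t e_ = (cmd);                                            \
    if (e_ != hipSuccess) {                                           \
      std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(e_)); \
      std::abort();                                                   \
    }                                                                 \
  } while (0)

int main() {
  hipStream_t stream;
  CHECK_HIP(hipStreamCreate(&stream));

  int* dev = nullptr;
  // Allocation and free are both ordered on the stream, so the free cannot
  // run before the memset (a stand-in for the grouped-GEMM launch) finishes.
  CHECK_HIP(hipMallocAsync(reinterpret_cast<void**>(&dev), 256 * sizeof(int), stream));
  CHECK_HIP(hipMemsetAsync(dev, 0, 256 * sizeof(int), stream));
  CHECK_HIP(hipFreeAsync(dev, stream));

  CHECK_HIP(hipStreamSynchronize(stream));
  CHECK_HIP(hipStreamDestroy(stream));
}
```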
@@ -1284,18 +1454,19 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
   // Check compute_stream_offset valid.
   NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
-  // hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
-  // hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
-  hipblaslt_ext::UserArguments* userArgs;
-  NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
-  hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
+  hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
+  HipBlasLtDeviceUserArgs* device_user_args = getHipBlasLtDeviceUserArgs(stream, m.size());
+  hipblaslt_ext::UserArguments* d_userArgs = device_user_args->get();
+  HipBlasltHostUserArgsBuffer* host_user_args_buffer = getHipBlasLtHostUserArgsBuffer(m.size());
+  HipBlasltHostUserArgs& host_user_args = host_user_args_buffer->getHostUserArgs();
+  hipblaslt_ext::UserArguments* userArgs = host_user_args.getArgs();
+  hipEvent_t event = host_user_args.getEvent();
   const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
   const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
   const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);
   hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
   float one = 1.0;
@@ -1312,16 +1483,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
     computeType = HIPBLAS_COMPUTE_32I;
   }
-  hipblaslt_ext::GemmPreference gemmPref;
-  gemmPref.setMaxWorkspaceBytes(workspaceSize);
-  hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
-                                         computeType);
-  std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
-      hipblaslt_ext::
-          GemmEpilogue()};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
+  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
+  std::vector<hipblaslt_ext::GemmEpilogue> epilogue{hipblaslt_ext::GemmEpilogue()};
   std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
   for (int i = 0; i < m.size(); i++) {
+    assert(m[i] != 0);
+    assert(n[i] != 0);
+    assert(k[i] != 0);
+    assert(b[i] != 0);
     inputs[i].a = inputA[i]->data.dptr;
     inputs[i].b = inputB[i]->data.dptr;
     inputs[i].c = outputD[i]->data.dptr;
@@ -1329,35 +1498,34 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
     inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
     inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
   }
-  // hipblaslt_ext::GemmEpilogue supports broadcasting
-  groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
   const int request_solutions = 1;
   std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
-  NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
+  NVTE_CHECK_CUDA(hipEventSynchronize(event));
+  hipblaslt_ext::GemmPreference gemmPref;
+  gemmPref.setMaxWorkspaceBytes(0);
+  hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type, computeType);
+  // hipblaslt_ext::GemmEpilogue supports broadcasting
+  groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
+  NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
   if (heuristicResult.empty()) {
     std::cerr << "No valid solution found!" << std::endl;
     return;
   }
   // Make sure to initialize everytime the algo changes
-  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
+  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, nullptr, true, stream));
   // Get the default values from the grouepdgemm object
   groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
   // Copy them to device memory
-  hipblaslt_ext::UserArguments* d_userArgs;
-  NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
-  NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
-                                 hipMemcpyHostToDevice, stream));
+  NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), hipMemcpyHostToDevice, stream));
   NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
-  NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
-  NVTE_CHECK_CUDA(hipFreeAsync(userArgs, stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
+  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
+  NVTE_CHECK_CUDA(hipEventRecord(event, stream));
 }
 #endif //USE_HIPBLASLT
...
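Taken together, the commit replaces the per-call hipHostMalloc/hipFreeAsync pair with a recycled set of pinned slots: synchronize on the slot's event before rewriting it, enqueue the host-to-device copy of the UserArguments, then record the event after the copy so the next reuse of the slot waits only if the GPU is behind. An end-to-end sketch of that cycle with an illustrative float payload (not the hipBLASLt API):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_HIP(cmd)                                                \
  do {                                                                \
    hipError_t e_ = (cmd);                                            \
    if (e_ != hipSuccess) {                                           \
      std::fprintf(stderr, "HIP error: %s\n", hipGetErrorString(e_)); \
      std::abort();                                                   \
    }                                                                 \
  } while (0)

struct Slot {
  float* host = nullptr;       // pinned staging memory
  hipEvent_t event = nullptr;  // completion of the last copy from this slot
};

int main() {
  const size_t n = 1024;
  const int kSlots = 8;  // matches the 8-entry buffer in the diff
  Slot slots[kSlots];
  for (auto& s : slots) {
    CHECK_HIP(hipHostMalloc(&s.host, n * sizeof(float)));
    CHECK_HIP(hipEventCreateWithFlags(&s.event, hipEventBlockingSync));
  }

  hipStream_t stream;
  float* dev = nullptr;
  CHECK_HIP(hipStreamCreate(&stream));
  CHECK_HIP(hipMalloc(&dev, n * sizeof(float)));

  for (int iter = 0; iter < 32; ++iter) {
    Slot& s = slots[iter % kSlots];
    CHECK_HIP(hipEventSynchronize(s.event));  // 1. GPU done with this slot? (no-op if never recorded)
    for (size_t i = 0; i < n; ++i) s.host[i] = static_cast<float>(iter);  // 2. rewrite on the host
    CHECK_HIP(hipMemcpyAsync(dev, s.host, n * sizeof(float),
                             hipMemcpyHostToDevice, stream));             // 3. stage to device
    CHECK_HIP(hipEventRecord(s.event, stream));                           // 4. mark reusable-when-done
  }

  CHECK_HIP(hipStreamSynchronize(stream));
  for (auto& s : slots) {
    CHECK_HIP(hipEventDestroy(s.event));
    CHECK_HIP(hipHostFree(s.host));
  }
  CHECK_HIP(hipFree(dev));
  CHECK_HIP(hipStreamDestroy(stream));
}
```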