Commit 1261da47 authored by wenjh
Browse files

Complete the fix for the hipBLASLt grouped GEMM dump


Signed-off-by: wenjh <wenjh@sugon.com>
parent 0a90777e
...@@ -991,6 +991,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, ...@@ -991,6 +991,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
bool use_split_accumulator, int math_sm_count, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_grouped_gemm); NVTE_API_CALL(nvte_grouped_gemm);
if(num_gemms == 0) { return; }
using namespace transformer_engine; using namespace transformer_engine;
std::vector<const Tensor*> inputA; std::vector<const Tensor*> inputA;
...@@ -1029,6 +1030,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, ...@@ -1029,6 +1030,7 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
n.push_back(B0); n.push_back(B0);
} }
} }
Tensor *wspace = convertNVTETensorCheck(workspace[0]); Tensor *wspace = convertNVTETensorCheck(workspace[0]);
if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) { if ((biasTensor[0]->data.dptr != nullptr) || (outputGelu[0]->data.dptr != nullptr)) {
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "../util/hip_runtime.h"
#endif #endif
#ifdef USE_ROCBLAS #ifdef USE_ROCBLAS
...@@ -1244,39 +1245,57 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD, ...@@ -1244,39 +1245,57 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc)); NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
} }
struct HipBlasltHostUserArgs struct HipBlasltUserArgs
{ {
HipBlasltHostUserArgs(): raw_(nullptr), event_(nullptr) {} HipBlasltUserArgs(): stream_(nullptr), raw_(nullptr), event_(nullptr) {}
HipBlasltHostUserArgs(size_t size): raw_(nullptr), event_(nullptr) HipBlasltUserArgs(hipStream_t stream, size_t size, bool host): stream_(stream), raw_(nullptr), event_(nullptr)
{ {
hipblaslt_ext::UserArguments* raw_ptr = nullptr; hipblaslt_ext::UserArguments* raw_ptr = nullptr;
if(host) {
NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments))); NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
else {
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
raw_ = raw_ptr; raw_ = raw_ptr;
hipEvent_t event = nullptr; hipEvent_t event = nullptr;
if(host) {
NVTE_CHECK_CUDA(hipEventCreateWithFlags(&event, hipEventBlockingSync)); NVTE_CHECK_CUDA(hipEventCreateWithFlags(&event, hipEventBlockingSync));
}
else {
NVTE_CHECK_CUDA(hipEventCreateWithFlags(&event, hipEventDisableTiming));
}
event_ = event; event_ = event;
} }
HipBlasltHostUserArgs(const HipBlasltHostUserArgs&) = delete; HipBlasltUserArgs(const HipBlasltUserArgs&) = delete;
HipBlasltHostUserArgs(HipBlasltHostUserArgs&& other) HipBlasltUserArgs(HipBlasltUserArgs&& other)
{ {
stream_ = other.stream_;
raw_ = other.raw_; raw_ = other.raw_;
event_ = other.event_; event_ = other.event_;
other.stream_ = nullptr;
other.raw_ = nullptr; other.raw_ = nullptr;
other.event_ = nullptr; other.event_ = nullptr;
} }
HipBlasltHostUserArgs& operator=(const HipBlasltHostUserArgs&) = delete; HipBlasltUserArgs& operator=(const HipBlasltUserArgs&) = delete;
HipBlasltHostUserArgs& operator=(HipBlasltHostUserArgs&& other) HipBlasltUserArgs& operator=(HipBlasltUserArgs&& other)
{ {
if(this != &other) if(this != &other)
{ {
free(); free();
stream_ = other.stream_;
raw_ = other.raw_; raw_ = other.raw_;
event_ = other.event_; event_ = other.event_;
other.stream_ = nullptr;
other.raw_ = nullptr; other.raw_ = nullptr;
other.event_ = nullptr; other.event_ = nullptr;
} }
return *this; return *this;
} }
inline hipStream_t getStream() const noexcept
{
return stream_;
}
inline hipblaslt_ext::UserArguments* getArgs() const noexcept inline hipblaslt_ext::UserArguments* getArgs() const noexcept
{ {
return raw_; return raw_;
...@@ -1285,12 +1304,16 @@ struct HipBlasltHostUserArgs ...@@ -1285,12 +1304,16 @@ struct HipBlasltHostUserArgs
{ {
return event_; return event_;
} }
~HipBlasltHostUserArgs() inline void setStream(hipStream_t stream) noexcept
{
stream_ = stream;
}
~HipBlasltUserArgs()
{ {
free(); free();
} }
private: private:
inline void free() void free()
{ {
if(raw_) if(raw_)
{ {
...@@ -1304,34 +1327,35 @@ private: ...@@ -1304,34 +1327,35 @@ private:
raw_ = nullptr; raw_ = nullptr;
} }
} }
hipStream_t stream_;
hipblaslt_ext::UserArguments* raw_; hipblaslt_ext::UserArguments* raw_;
hipEvent_t event_; hipEvent_t event_;
}; };
struct HipBlasltHostUserArgsBuffer struct HipBlasltUserArgsBuffer
{ {
HipBlasltHostUserArgsBuffer() {} HipBlasltUserArgsBuffer() {}
HipBlasltHostUserArgsBuffer(size_t size) HipBlasltUserArgsBuffer(hipStream_t stream, size_t size, bool host)
{ {
for(int i = 0; i < 8; ++i) for(int i = 0; i < 4; ++i)
{ {
buffer_[i] = std::move(HipBlasltHostUserArgs(size)); buffer_[i] = std::move(HipBlasltUserArgs(stream, size, host));
} }
} }
HipBlasltHostUserArgsBuffer(const HipBlasltHostUserArgsBuffer&) = delete; HipBlasltUserArgsBuffer(const HipBlasltUserArgsBuffer&) = delete;
HipBlasltHostUserArgsBuffer(HipBlasltHostUserArgsBuffer&& other) { HipBlasltUserArgsBuffer(HipBlasltUserArgsBuffer&& other) {
for(int i = 0; i < 8; ++i) for(int i = 0; i < 4; ++i)
{ {
buffer_[i] = std::move(other.buffer_[i]); buffer_[i] = std::move(other.buffer_[i]);
} }
index_ = other.index_; index_ = other.index_;
} }
HipBlasltHostUserArgsBuffer& operator=(const HipBlasltHostUserArgsBuffer&) = delete; HipBlasltUserArgsBuffer& operator=(const HipBlasltUserArgsBuffer&) = delete;
HipBlasltHostUserArgsBuffer& operator=(HipBlasltHostUserArgsBuffer&& other) HipBlasltUserArgsBuffer& operator=(HipBlasltUserArgsBuffer&& other)
{ {
if(this != &other) if(this != &other)
{ {
for(int i = 0; i < 8; ++i) for(int i = 0; i < 4; ++i)
{ {
buffer_[i] = std::move(other.buffer_[i]); buffer_[i] = std::move(other.buffer_[i]);
} }
...@@ -1339,11 +1363,11 @@ struct HipBlasltHostUserArgsBuffer ...@@ -1339,11 +1363,11 @@ struct HipBlasltHostUserArgsBuffer
} }
return *this; return *this;
} }
HipBlasltHostUserArgs& getHostUserArgs() HipBlasltUserArgs& getUserArgs()
{ {
HipBlasltHostUserArgs& args = buffer_[index_]; HipBlasltUserArgs& args = buffer_[index_];
if(index_ < 7) if(index_ < 3)
{ {
++index_; ++index_;
} }
...@@ -1356,94 +1380,48 @@ struct HipBlasltHostUserArgsBuffer ...@@ -1356,94 +1380,48 @@ struct HipBlasltHostUserArgsBuffer
} }
private: private:
int index_ = 0; int index_ = 0;
HipBlasltHostUserArgs buffer_[8]; HipBlasltUserArgs buffer_[4];
}; };
using HipBlasLtHostUserArgsBufferPtr = std::unique_ptr<HipBlasltHostUserArgsBuffer>; // using HipBlasltUserArgsBufferPtr = std::unique_ptr<HipBlasltUserArgsBuffer>;
HipBlasltHostUserArgsBuffer* getHipBlasLtHostUserArgsBuffer(size_t size) struct HipBlasltUserArgsCache
{ {
static thread_local std::unordered_map<size_t, HipBlasLtHostUserArgsBufferPtr> user_args_cache; HipBlasltUserArgsCache() {}
auto size_it = user_args_cache.find(size); HipBlasltUserArgsCache(const HipBlasltUserArgsCache&) = delete;
if (size_it != user_args_cache.end()) { HipBlasltUserArgsBuffer& operator=(const HipBlasltUserArgsBuffer&) = delete;
return size_it->second.get(); HipBlasltUserArgsBuffer& getBuffer(hipStream_t stream, size_t size, bool host)
}
else
{ {
HipBlasLtHostUserArgsBufferPtr user_args(new HipBlasltHostUserArgsBuffer(size)); std::unordered_map<size_t, HipBlasltUserArgsBuffer>& buffers = host ? host_buffers_: device_buffers_;
HipBlasltHostUserArgsBuffer* raw_ptr = user_args.get(); auto size_it = buffers.find(size);
user_args_cache[size] = std::move(user_args); if (size_it != buffers.end()) {
return raw_ptr; return size_it->second;
} }
} else
struct HipBlasLtDeviceUserArgs {
HipBlasLtDeviceUserArgs(): stream_(nullptr), raw_(nullptr) {}
HipBlasLtDeviceUserArgs(hipStream_t stream, size_t size): stream_(stream), raw_(nullptr)
{
hipblaslt_ext::UserArguments* raw_ptr = nullptr;
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
raw_ = raw_ptr;
}
HipBlasLtDeviceUserArgs(const HipBlasLtDeviceUserArgs&) = delete;
HipBlasLtDeviceUserArgs(HipBlasLtDeviceUserArgs&& other)
{
stream_ = other.stream_;
raw_ = other.raw_;
other.stream_ = nullptr;
other.raw_ = nullptr;
}
HipBlasLtDeviceUserArgs& operator=(const HipBlasLtDeviceUserArgs&) = delete;
HipBlasLtDeviceUserArgs& operator=(HipBlasLtDeviceUserArgs&& other)
{
if(this != &other)
{
free();
stream_ = other.stream_;
raw_ = other.raw_;
other.stream_ = nullptr;
other.raw_ = nullptr;
}
return *this;
}
inline hipblaslt_ext::UserArguments* get() const noexcept
{
return raw_;
}
~HipBlasLtDeviceUserArgs()
{
free();
}
protected:
inline void free()
{
if(raw_)
{ {
NVTE_CHECK_CUDA(hipFreeAsync(raw_, stream_)); return buffers.emplace(size, HipBlasltUserArgsBuffer{stream, size, host}).first->second;
raw_ = nullptr;
} }
} }
hipStream_t stream_; private:
hipblaslt_ext::UserArguments* raw_; std::unordered_map<size_t, HipBlasltUserArgsBuffer> host_buffers_;
std::unordered_map<size_t, HipBlasltUserArgsBuffer> device_buffers_;
}; };
using HipBlasLtDeviceUserArgsPtr = std::unique_ptr<HipBlasLtDeviceUserArgs>; struct HipBlasltUserArgsCacheManager {
static HipBlasltUserArgsCacheManager& instance() {
HipBlasLtDeviceUserArgs* getHipBlasLtDeviceUserArgs(hipStream_t stream, size_t size) static thread_local HipBlasltUserArgsCacheManager instance_;
{ return instance_;
static thread_local std::unordered_map<size_t, HipBlasLtDeviceUserArgsPtr> user_args_cache;
auto size_it = user_args_cache.find(size);
if (size_it != user_args_cache.end()) {
return size_it->second.get();
} }
else
{ HipBlasltUserArgsCache& getCache() {
HipBlasLtDeviceUserArgsPtr user_args(new HipBlasLtDeviceUserArgs(stream, size)); const int device_id = cuda::current_device();
HipBlasLtDeviceUserArgs* raw_ptr = user_args.get(); NVTE_CHECK(0 <= device_id && device_id < caches_.size(), "invalid CUDA device ID");
user_args_cache[size] = std::move(user_args); return caches_[device_id];
return raw_ptr;
} }
} private:
HipBlasltUserArgsCacheManager() : caches_(cuda::num_devices()) {}
std::vector<HipBlasltUserArgsCache> caches_;
};
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m, std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
...@@ -1456,13 +1434,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1456,13 +1434,14 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle(); hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
HipBlasLtDeviceUserArgs* device_user_args = getHipBlasLtDeviceUserArgs(stream, m.size()); HipBlasltUserArgs& device_user_args = HipBlasltUserArgsCacheManager::instance().getCache().getBuffer(stream, m.size(), false).getUserArgs();
hipblaslt_ext::UserArguments* d_userArgs = device_user_args->get(); hipblaslt_ext::UserArguments* device_args = device_user_args.getArgs();
hipEvent_t device_event = device_user_args.getEvent();
hipStream_t device_stream = device_user_args.getStream();
HipBlasltHostUserArgsBuffer* host_user_args_buffer = getHipBlasLtHostUserArgsBuffer(m.size()); HipBlasltUserArgs& host_user_args = HipBlasltUserArgsCacheManager::instance().getCache().getBuffer(stream, m.size(), true).getUserArgs();
HipBlasltHostUserArgs& host_user_args = host_user_args_buffer->getHostUserArgs(); hipblaslt_ext::UserArguments* host_args = host_user_args.getArgs();
hipblaslt_ext::UserArguments* userArgs = host_user_args.getArgs(); hipEvent_t host_event = host_user_args.getEvent();
hipEvent_t event = host_user_args.getEvent();
const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype); const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype); const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
...@@ -1502,8 +1481,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1502,8 +1481,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
const int request_solutions = 1; const int request_solutions = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult; std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
NVTE_CHECK_CUDA(hipEventSynchronize(event));
hipblaslt_ext::GemmPreference gemmPref; hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(0); gemmPref.setMaxWorkspaceBytes(0);
...@@ -1519,13 +1496,19 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1519,13 +1496,19 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Make sure to initialize everytime the algo changes // Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, nullptr, true, stream)); NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, nullptr, true, stream));
NVTE_CHECK_CUDA(hipEventSynchronize(host_event));
// Get the default values from the grouepdgemm object // Get the default values from the grouepdgemm object
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs); groupedgemm.getDefaultValueForDeviceUserArguments(host_args);
if(stream != device_stream) {
NVTE_CHECK_CUDA(hipStreamWaitEvent(stream, device_event, 0));
}
// Copy them to device memory // Copy them to device memory
NVTE_CHECK_CUDA(hipMemcpyAsync(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), hipMemcpyHostToDevice, stream)); NVTE_CHECK_CUDA(hipMemcpyAsync(device_args, host_args, m.size() * sizeof(hipblaslt_ext::UserArguments), hipMemcpyHostToDevice, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream)); NVTE_CHECK_CUDA(hipEventRecord(host_event, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(device_args, stream));
NVTE_CHECK_CUDA(hipEventRecord(event, stream)); device_user_args.setStream(stream);
NVTE_CHECK_CUDA(hipEventRecord(device_event, stream));
} }
#endif //USE_HIPBLASLT #endif //USE_HIPBLASLT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment