Commit a13c52ad authored by wenjh's avatar wenjh
Browse files

Fix user args core dump in multi-threaded (MT) execution

parent 3a5755b1
...@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD, ...@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc)); NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
} }
struct HipBlasLtUserArgsDeleter {
class userArgsManager { void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
public: hipFree(ptr);
userArgsManager() {}
~userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : userArgs_map_) {
hipFree(device_pair.second); // Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = userArgs_map_.find(device_id);
if (device_it != userArgs_map_.end()) {
return device_it->second;
} }
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext::UserArguments* userArgs;
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
// Store the userArgs in the map for this device
userArgs_map_[device_id] = userArgs;
return userArgs;
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*>
userArgs_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
}; };
class d_userArgsManager { using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
public:
d_userArgsManager() {}
~d_userArgsManager() { inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
// Release all userArgs when the manager is destroyed hipblaslt_ext::UserArguments* raw_ptr = nullptr;
for (auto& device_pair : d_userArgs_map_) { if (host) {
hipFree(device_pair.second); // Only one userArgs per device NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
} } else {
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
} }
return HipBlasLtUserArgsPtr(raw_ptr);
}
// Return a cached UserArguments buffer holding `size` elements: pinned host
// memory when `host` is true, device memory otherwise. Buffers are cached
// per thread and per size; a cache entry is created on first request and
// released (via HipBlasLtUserArgsPtr's deleter) at thread exit.
inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
  // Thread-local caches avoid cross-thread sharing of the buffers (the MT
  // core-dump fix): no locking is needed and no two threads ever hand out
  // the same allocation.
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
  auto& cache = host ? host_userargs_cache : device_userargs_cache;

  // NOTE(review): the cache key is the element count only — the current HIP
  // device is not part of the key (the previous managers keyed by device_id).
  // This assumes each thread stays on a single device; confirm with callers.
  auto hit = cache.find(size);
  if (hit != cache.end()) {
    return hit->second.get();
  }

  // First request at this size on this thread: allocate and cache.
  auto inserted = cache.emplace(size, make_hipblaslt_user_args_ptr(size, host));
  return inserted.first->second.get();
}
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m, std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
...@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const ...@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid. // Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams); NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
int device_id; hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
hipGetDevice(&device_id); hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
// hipblaslt_ext::UserArguments* userArgs; // hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments))); // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment