Commit a13c52ad authored by wenjh's avatar wenjh
Browse files

Fix user args core dump in mt

parent 3a5755b1
......@@ -1352,82 +1352,41 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
class userArgsManager {
public:
userArgsManager() {}
~userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : userArgs_map_) {
hipFree(device_pair.second); // Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = userArgs_map_.find(device_id);
if (device_it != userArgs_map_.end()) {
return device_it->second;
struct HipBlasLtUserArgsDeleter {
void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
hipFree(ptr);
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext::UserArguments* userArgs;
NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
// Store the userArgs in the map for this device
userArgs_map_[device_id] = userArgs;
return userArgs;
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*>
userArgs_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
};
class d_userArgsManager {
public:
d_userArgsManager() {}
using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
~d_userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : d_userArgs_map_) {
hipFree(device_pair.second); // Only one userArgs per device
}
inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
hipblaslt_ext::UserArguments* raw_ptr = nullptr;
if (host) {
NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
} else {
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
return HipBlasLtUserArgsPtr(raw_ptr);
}
// Get a userArgs for the given device (creates if necessary)
hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = d_userArgs_map_.find(device_id);
if (device_it != d_userArgs_map_.end()) {
return device_it->second;
inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
auto size_it = user_args_cache.find(size);
if (size_it != user_args_cache.end()) {
return size_it->second.get();
}
// Create a new userArgs for this device if it doesn't exist
hipblaslt_ext::UserArguments* d_userArgs;
NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
// Store the userArgs in the map for this device
d_userArgs_map_[device_id] = d_userArgs;
return d_userArgs;
else
{
HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
user_args_cache[size] = std::move(user_args);
return raw_ptr;
}
}
private:
std::unordered_map<int, hipblaslt_ext::UserArguments*>
d_userArgs_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
};
// Define a static userArgs manager
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
......@@ -1438,10 +1397,8 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
int device_id;
hipGetDevice(&device_id);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment