Commit 3cdceb87 authored by wenjh
Browse files

Delete tmpArgs in groupedgemm


Signed-off-by: wenjh <wenjh@sugon.com>
parent 1f97aebb
......@@ -1478,45 +1478,9 @@ private:
std::mutex mutex_;
};
class tmp_userArgsManager {
public:
tmp_userArgsManager() {}
~tmp_userArgsManager() {
// Release all userArgs when the manager is destroyed
for (auto& device_pair : tmp_userArgs_map_) {
hipFree(device_pair.second); // Only one userArgs per device
}
}
// Get a userArgs for the given device (creates if necessary)
void* get(int device_id, size_t size) {
std::lock_guard<std::mutex> lock(mutex_);
// Check if the userArgs for this device exists
auto device_it = tmp_userArgs_map_.find(device_id);
if (device_it != tmp_userArgs_map_.end()) {
return device_it->second;
}
// Create a new userArgs for this device if it doesn't exist
void* tmp_userArgs;
NVTE_CHECK_CUDA(hipHostMalloc(&tmp_userArgs, size));
// Store the userArgs in the map for this device
tmp_userArgs_map_[device_id] = tmp_userArgs;
return tmp_userArgs;
}
private:
std::unordered_map<int, void*> tmp_userArgs_map_; // Map from device_id to hipblasHandle
std::mutex mutex_;
};
// Define a static userArgs manager
// File-scope singletons shared by all grouped-GEMM calls in this translation
// unit; each one caches a per-device buffer (presumably host-side
// UserArguments, their device-side copy, and a scratch area — the first two
// manager classes are defined above this chunk, so confirm there).
// NOTE(review): static non-trivial globals rely on construction/destruction
// order at program exit; freeing HIP memory in their destructors assumes the
// HIP runtime is still alive at that point — verify against teardown order.
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
static tmp_userArgsManager tmp_UAManager;
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB, std::vector<Tensor*>& outputD,
std::vector<int64_t>& m, std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b, hipblasOperation_t transa, hipblasOperation_t transb,
......@@ -1529,7 +1493,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
hipGetDevice(&device_id);
hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
void* tmp_userArgs = tmp_UAManager.get(device_id, 32768);
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
......@@ -1573,8 +1536,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
B_type,
D_type,
D_type,
computeType,
tmp_userArgs);
computeType);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
hipblaslt_ext::
......@@ -1605,6 +1567,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
// Get the default values from the grouepdgemm object
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
// Copy them to device memory
......@@ -1614,6 +1577,7 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
userArgs,
m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment