"examples/vscode:/vscode.git/clone" did not exist on "c2a643d50b91ce885f5f1b1bd144651bc86dff22"
Commit ee3c1c81 authored by wenjh
Browse files

Merge develop_v2.9 to release_v2.9


Signed-off-by: wenjh <wenjh@sugon.com>
parents cd6bd507 e698a0a7
......@@ -432,7 +432,7 @@ def test_fuser_ops_with_userbuffers(
command = []
if tex.ubuf_built_with_mpi():
python_exe = pathlib.Path(sys.executable).resolve()
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--quiet", python_exe))
command.extend(("mpirun", "-np", str(world_size), "--oversubscribe", "--allow-run-as-root", python_exe))
else:
command.extend(("torchrun", f"--nproc_per_node={world_size}"))
......
......@@ -340,6 +340,7 @@ else()
fused_attn/flash_attn.cu
fused_attn/context_parallel.cu
fused_attn/kv_cache.cu
fused_attn/utils.cu
gemm/cublaslt_gemm.cu
gemm/hipblas_gemm.cu
normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
......
......@@ -44,7 +44,17 @@ __device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) {
// Bitwise logical op to get answer in MSBs
// Equivalent logic: result = (a == b) ? !result : b
#ifdef __HIP_PLATFORM_AMD__
result = (a == b) ? !result : b;
// Use HIP vector types for byte-wise parallel comparison
union { uint32_t u32; uint8_t bytes[4]; } a_union, b_union;
a_union.u32 = a;
b_union.u32 = b;
uint32_t mask = 0;
mask |= (a_union.bytes[0] < b_union.bytes[0]) ? 0x80000000U : 0;
mask |= (a_union.bytes[1] < b_union.bytes[1]) ? 0x00800000U : 0;
mask |= (a_union.bytes[2] < b_union.bytes[2]) ? 0x00008000U : 0;
mask |= (a_union.bytes[3] < b_union.bytes[3]) ? 0x00000080U : 0;
result = mask;
#else
asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result));
#endif
......
......@@ -14,7 +14,7 @@
namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
// get matrix strides based on matrix type
......@@ -610,7 +610,7 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
return hout;
}
#endif
__global__ void extract_seed_and_offset(int64_t *rng_state_ptr, bool captured, int64_t *seed_ptr,
uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val,
uint32_t offset_intragraph) {
......
......@@ -7,10 +7,11 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif
#include <cstdint>
#include <mutex>
......@@ -19,7 +20,7 @@
namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
enum NVTE_QKV_Matrix {
......@@ -187,7 +188,7 @@ void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q
cudaStream_t stream);
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
#endif
} // namespace fused_attn
} // namespace transformer_engine
......
......@@ -1367,7 +1367,8 @@ void nvte_grouped_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
bool use_split_accumulator, int math_sm_count,
cudaStream_t stream) {
using namespace transformer_engine;
if(num_gemms == 0) { return; }
std::vector<const Tensor*> inputA;
std::vector<const Tensor*> inputB;
std::vector<Tensor*> outputD;
......
......@@ -20,6 +20,7 @@
#include <sstream>
#include <unordered_map>
#include <vector>
#include "../util/hip_runtime.h"
#endif
#ifdef USE_ROCBLAS
......@@ -887,11 +888,17 @@ static inline int getIntEnv(const char* name, int defval, int minval) {
} //namespace
// Creates a hipBLASLt library handle, aborting via NVTE_CHECK_HIPBLASLT on
// failure. Used as the creation callback for the HandleManager alias below.
static void CreateHipBlasLtHandle(hipblasLtHandle_t* handle) {
  NVTE_CHECK_HIPBLASLT(hipblasLtCreate(handle));
}
// Destroys a handle previously created by CreateHipBlasLtHandle.
// A null handle is ignored so the manager may safely pass through unset slots.
static void DestroyHipBlasLtHandle(hipblasLtHandle_t handle) {
  if (handle != nullptr) {
    NVTE_CHECK_HIPBLASLT(hipblasLtDestroy(handle));
  }
}

// Handle cache wiring both the create and destroy callbacks so cached
// handles are released when the manager is torn down.
using hipBlasLtHandleManager = detail::HandleManager<hipblasLtHandle_t, CreateHipBlasLtHandle, DestroyHipBlasLtHandle>;
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
using namespace transformer_engine;
......@@ -1240,40 +1247,183 @@ void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
struct HipBlasLtUserArgsDeleter {
void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
hipFree(ptr);
struct HipBlasltUserArgs
{
HipBlasltUserArgs(): stream_(nullptr), raw_(nullptr), event_(nullptr) {}
HipBlasltUserArgs(hipStream_t stream, size_t size, bool host): stream_(stream), raw_(nullptr), event_(nullptr)
{
hipblaslt_ext::UserArguments* raw_ptr = nullptr;
if(host) {
NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
else {
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
}
raw_ = raw_ptr;
hipEvent_t event = nullptr;
if(host) {
NVTE_CHECK_CUDA(hipEventCreateWithFlags(&event, hipEventBlockingSync));
}
else {
NVTE_CHECK_CUDA(hipEventCreateWithFlags(&event, hipEventDisableTiming));
}
event_ = event;
}
HipBlasltUserArgs(const HipBlasltUserArgs&) = delete;
HipBlasltUserArgs(HipBlasltUserArgs&& other)
{
stream_ = other.stream_;
raw_ = other.raw_;
event_ = other.event_;
other.stream_ = nullptr;
other.raw_ = nullptr;
other.event_ = nullptr;
}
HipBlasltUserArgs& operator=(const HipBlasltUserArgs&) = delete;
HipBlasltUserArgs& operator=(HipBlasltUserArgs&& other)
{
if(this != &other)
{
free();
stream_ = other.stream_;
raw_ = other.raw_;
event_ = other.event_;
other.stream_ = nullptr;
other.raw_ = nullptr;
other.event_ = nullptr;
}
return *this;
}
inline hipStream_t getStream() const noexcept
{
return stream_;
}
inline hipblaslt_ext::UserArguments* getArgs() const noexcept
{
return raw_;
}
inline hipEvent_t getEvent() const noexcept
{
return event_;
}
inline void setStream(hipStream_t stream) noexcept
{
stream_ = stream;
}
~HipBlasltUserArgs()
{
free();
}
private:
void free()
{
if(raw_)
{
if(event_)
{
NVTE_CHECK_CUDA(hipEventSynchronize(event_));
NVTE_CHECK_CUDA(hipEventDestroy(event_));
event_ = nullptr;
}
NVTE_CHECK_CUDA(hipFree(raw_));
raw_ = nullptr;
}
}
hipStream_t stream_;
hipblaslt_ext::UserArguments* raw_;
hipEvent_t event_;
};
using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;
inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
hipblaslt_ext::UserArguments* raw_ptr = nullptr;
if (host) {
NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
} else {
NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
struct HipBlasltUserArgsBuffer
{
HipBlasltUserArgsBuffer() {}
HipBlasltUserArgsBuffer(hipStream_t stream, size_t size, bool host)
{
for(int i = 0; i < 4; ++i)
{
buffer_[i] = std::move(HipBlasltUserArgs(stream, size, host));
}
}
return HipBlasLtUserArgsPtr(raw_ptr);
}
HipBlasltUserArgsBuffer(const HipBlasltUserArgsBuffer&) = delete;
HipBlasltUserArgsBuffer(HipBlasltUserArgsBuffer&& other) {
for(int i = 0; i < 4; ++i)
{
buffer_[i] = std::move(other.buffer_[i]);
}
index_ = other.index_;
}
HipBlasltUserArgsBuffer& operator=(const HipBlasltUserArgsBuffer&) = delete;
HipBlasltUserArgsBuffer& operator=(HipBlasltUserArgsBuffer&& other)
{
if(this != &other)
{
for(int i = 0; i < 4; ++i)
{
buffer_[i] = std::move(other.buffer_[i]);
}
index_ = other.index_;
}
return *this;
}
HipBlasltUserArgs& getUserArgs()
{
HipBlasltUserArgs& args = buffer_[index_];
inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache = host ? host_userargs_cache : device_userargs_cache;
auto size_it = user_args_cache.find(size);
if (size_it != user_args_cache.end()) {
return size_it->second.get();
if(index_ < 3)
{
++index_;
}
else
{
index_ = 0;
}
return args;
}
else
private:
int index_ = 0;
HipBlasltUserArgs buffer_[4];
};
// using HipBlasltUserArgsBufferPtr = std::unique_ptr<HipBlasltUserArgsBuffer>;
// Maps problem size -> argument-buffer pool, kept separately for host and
// device allocations. Entries are created lazily on first request. Stale
// old-version diff lines interleaved into getBuffer() were removed.
struct HipBlasltUserArgsCache
{
  HipBlasltUserArgsCache() {}
  HipBlasltUserArgsCache(const HipBlasltUserArgsCache&) = delete;
  // Bug fix: the deleted copy-assignment previously named
  // HipBlasltUserArgsBuffer, declaring an unrelated overload instead of
  // deleting this class's own copy-assignment.
  HipBlasltUserArgsCache& operator=(const HipBlasltUserArgsCache&) = delete;
  // Returns (creating on demand) the buffer pool for `size` arguments on the
  // requested side; `stream` is only used when a new pool must be allocated.
  HipBlasltUserArgsBuffer& getBuffer(hipStream_t stream, size_t size, bool host)
  {
    std::unordered_map<size_t, HipBlasltUserArgsBuffer>& buffers = host ? host_buffers_: device_buffers_;
    auto size_it = buffers.find(size);
    if (size_it != buffers.end()) {
      return size_it->second;
    }
    else
    {
      return buffers.emplace(size, HipBlasltUserArgsBuffer{stream, size, host}).first->second;
    }
  }
private:
  std::unordered_map<size_t, HipBlasltUserArgsBuffer> host_buffers_;
  std::unordered_map<size_t, HipBlasltUserArgsBuffer> device_buffers_;
};
// Thread-local singleton owning one HipBlasltUserArgsCache per visible device.
struct HipBlasltUserArgsCacheManager {
  // Accessor for the calling thread's manager instance (lazily constructed).
  static HipBlasltUserArgsCacheManager& instance() {
    static thread_local HipBlasltUserArgsCacheManager instance_;
    return instance_;
  }
  // Cache belonging to the currently active device.
  HipBlasltUserArgsCache& getCache() {
    const int dev = cuda::current_device();
    NVTE_CHECK(0 <= dev && dev < caches_.size(), "invalid CUDA device ID");
    return caches_[dev];
  }

 private:
  // One cache slot per device, sized once at construction.
  HipBlasltUserArgsCacheManager() : caches_(cuda::num_devices()) {}
  std::vector<HipBlasltUserArgsCache> caches_;
};
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
......@@ -1285,18 +1435,20 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
// Check compute_stream_offset valid.
NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
// hipblaslt_ext::UserArguments* userArgs;
// NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
HipBlasltUserArgs& device_user_args = HipBlasltUserArgsCacheManager::instance().getCache().getBuffer(stream, m.size(), false).getUserArgs();
hipblaslt_ext::UserArguments* device_args = device_user_args.getArgs();
hipEvent_t device_event = device_user_args.getEvent();
hipStream_t device_stream = device_user_args.getStream();
hipblasLtHandle_t handle = hipBlasLtHandleManager::Instance().GetHandle();
HipBlasltUserArgs& host_user_args = HipBlasltUserArgsCacheManager::instance().getCache().getBuffer(stream, m.size(), true).getUserArgs();
hipblaslt_ext::UserArguments* host_args = host_user_args.getArgs();
hipEvent_t host_event = host_user_args.getEvent();
const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);
hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
float one = 1.0;
......@@ -1313,16 +1465,14 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
computeType = HIPBLAS_COMPUTE_32I;
}
hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(workspaceSize);
hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
computeType);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
hipblaslt_ext::
GemmEpilogue()}; // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
// No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
std::vector<hipblaslt_ext::GemmEpilogue> epilogue{hipblaslt_ext::GemmEpilogue()};
std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
for (int i = 0; i < m.size(); i++) {
assert(m[i] != 0);
assert(n[i] != 0);
assert(k[i] != 0);
assert(b[i] != 0);
inputs[i].a = inputA[i]->data.dptr;
inputs[i].b = inputB[i]->data.dptr;
inputs[i].c = outputD[i]->data.dptr;
......@@ -1330,35 +1480,38 @@ void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
}
// hipblaslt_ext::GemmEpilogue supports broadcasting
groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
const int request_solutions = 1;
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
hipblaslt_ext::GemmPreference gemmPref;
gemmPref.setMaxWorkspaceBytes(0);
hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type, computeType);
// hipblaslt_ext::GemmEpilogue supports broadcasting
groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
if (heuristicResult.empty()) {
std::cerr << "No valid solution found!" << std::endl;
return;
}
// Make sure to initialize everytime the algo changes
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, nullptr, true, stream));
NVTE_CHECK_CUDA(hipEventSynchronize(host_event));
// Get the default values from the groupedgemm object
groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
groupedgemm.getDefaultValueForDeviceUserArguments(host_args);
if(stream != device_stream) {
NVTE_CHECK_CUDA(hipStreamWaitEvent(stream, device_event, 0));
}
// Copy them to device memory
// hipblaslt_ext::UserArguments* d_userArgs;
// NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
hipMemcpyHostToDevice));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
// NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
// NVTE_CHECK_CUDA(hipFreeAsync(d_userArgs, stream));
// NVTE_CHECK_CUDA(hipFree(userArgs));
NVTE_CHECK_CUDA(hipMemcpyAsync(device_args, host_args, m.size() * sizeof(hipblaslt_ext::UserArguments), hipMemcpyHostToDevice, stream));
NVTE_CHECK_CUDA(hipEventRecord(host_event, stream));
NVTE_CHECK_HIPBLASLT(groupedgemm.run(device_args, stream));
device_user_args.setStream(stream);
NVTE_CHECK_CUDA(hipEventRecord(device_event, stream));
}
#endif //USE_HIPBLASLT
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment