Commit 1a8114bf authored by fengzch-das

hipify code

parent c0177256
Pipeline #3049 canceled
......@@ -12,8 +12,8 @@ public:
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features,
......@@ -42,7 +42,7 @@ public:
std::string dumpTensorBF16(Tensor x) {
std::stringstream ss;
for (int i = 0; i < 256; i++) {
ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__hip_bfloat16>()[i]));
}
ss << std::endl;
return ss.str();
......
......@@ -12,8 +12,8 @@ public:
spdlog::info("Initializing QuantizedGEMM88");
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W8A8>(
......
......@@ -8,27 +8,27 @@ namespace nunchaku::utils {
void set_cuda_stack_limit(int64_t newval) {
size_t val = 0;
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
checkCUDA(hipDeviceSetLimit(hipLimitStackSize, (size_t)newval));
checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
spdlog::debug("Stack={}", val);
}
void disable_memory_auto_release() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
checkCUDA(hipGetDevice(&device));
hipMemPool_t mempool;
checkCUDA(hipDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
checkCUDA(hipMemPoolSetAttribute(mempool, hipMemPoolAttrReleaseThreshold, &threshold));
}
void trim_memory() {
int device;
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
checkCUDA(hipGetDevice(&device));
hipMemPool_t mempool;
checkCUDA(hipDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
checkCUDA(hipMemPoolTrimTo(mempool, bytesToKeep));
}
void set_faster_i2f_mode(std::string mode) {
......
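
For reference, the hipified memory-pool calls above also work as a standalone sketch against the HIP runtime (stream-ordered allocator, ROCm 5.2+); the plain error checks here stand in for the project's checkCUDA helper, and main is illustrative only:

#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
    int device = 0;
    if (hipGetDevice(&device) != hipSuccess) return 1;

    hipMemPool_t mempool;
    if (hipDeviceGetDefaultMemPool(&mempool, device) != hipSuccess) return 1;

    // disable_memory_auto_release: never give freed blocks back to the OS
    uint64_t threshold = UINT64_MAX;
    hipMemPoolSetAttribute(mempool, hipMemPoolAttrReleaseThreshold, &threshold);

    // trim_memory: release everything the pool is not actively using
    size_t bytesToKeep = 0;
    hipMemPoolTrimTo(mempool, bytesToKeep);

    std::printf("default mempool configured for device %d\n", device);
    return 0;
}
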
This diff is collapsed.
......@@ -5,7 +5,7 @@
#include "kernels/awq/gemv_awq.h"
#include "kernels/dwconv.h"
#include <nvtx3/nvToolsExt.h>
#include <roctracer/roctx.h>
using namespace nunchaku;
......@@ -117,7 +117,7 @@ GEMM_W4A4::GEMM_W4A4(
wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);
#if NO_LORA_FUSION
checkCUBLAS(cublasCreate(&handle));
checkCUBLAS(hipblasCreate(&handle));
#endif
}
......@@ -140,7 +140,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
*dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
*dst.data_ptr<float>() = float(*src.data_ptr<__hip_bfloat16>());
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
......@@ -242,15 +242,15 @@ void GEMM_W4A4::forward(Tensor x,
qact.is_unsigned,
this->lora_scales);
nvtxRangePushA("LoraUp");
roctxRangePushA("LoraUp");
static const half one = 1.0;
static const half zero = 0.0;
// lora_up: [M, R] * [OC, R] => [M, OC]
// cublas view: [OC, R] * [M, R]^T
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_T,
HIPBLAS_OP_N,
this->out_features,
M,
this->lora_rank,
......@@ -263,7 +263,7 @@ void GEMM_W4A4::forward(Tensor x,
out.data_ptr<half>(),
this->out_features));
nvtxRangePop();
roctxRangePop();
#endif
}
......@@ -380,7 +380,7 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
qact.is_unsigned,
this->lora_scales);
nvtxRangePushA("LoraUp");
roctxRangePushA("LoraUp");
static const half one = 1.0;
static const half zero = 0.0;
......@@ -388,9 +388,9 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
// lora_up: [M, R] * [OC, R]^T => [M, OC]
// cublas view: [R, OC]^T * [R, M] => [OC, M]
// lora_up layout wrong?
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_T,
HIPBLAS_OP_N,
this->out_features,
M,
this->lora_rank,
......@@ -403,16 +403,16 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
out.data_ptr<half>(),
this->out_features));
nvtxRangePop();
roctxRangePop();
if (fuse == FuseOptions::GELU_QUANT) {
nvtxRangePushA("LoraDown");
roctxRangePushA("LoraDown");
// IC is for next lora (OC of this layer)
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M] => [R, M]
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
this->lora_rank,
M,
this->out_features,
......@@ -427,7 +427,7 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
out = {};
nvtxRangePop();
roctxRangePop();
}
#endif
......@@ -473,13 +473,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
static const half one = 1.0;
static const half zero = 0.0;
nvtxRangePushA("LoraDown");
roctxRangePushA("LoraDown");
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M]
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
this->lora_rank,
M,
this->in_features,
......@@ -492,7 +492,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
qact.lora_act.data_ptr<half>(),
this->lora_rank));
nvtxRangePop();
roctxRangePop();
kernels::quantize_w4a4_act(x, qact.act, qact.ascales);
......
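
As a side note, the OP_T/OP_N flags in the hipblasHgemm calls above come from viewing row-major tensors through hipBLAS's column-major convention. A minimal sketch of the same mapping, written with hipblasSgemm to sidestep the half-precision pointer typedef; the function and parameter names here are illustrative, not part of the diff:

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>

// Row-major Out[M][OC] = Act[M][R] * LoraUp[OC][R]^T.
// Column-major hipBLAS view: Out_col(OC,M) = LoraUp_col(R,OC)^T * Act_col(R,M),
// which is exactly the HIPBLAS_OP_T / HIPBLAS_OP_N pair used above.
void lora_up_sgemm(hipblasHandle_t handle,
                   const float *act,      // M x R, row-major
                   const float *lora_up,  // OC x R, row-major
                   float *out,            // M x OC, row-major
                   int M, int R, int OC) {
    const float one = 1.0f, zero = 0.0f;
    hipblasSgemm(handle,
                 HIPBLAS_OP_T, HIPBLAS_OP_N,
                 OC, M, R,
                 &one,
                 lora_up, R,   // A: R x OC in the column-major view, lda = R
                 act, R,       // B: R x M in the column-major view, ldb = R
                 &zero,
                 out, OC);     // C: OC x M in the column-major view, ldc = OC
}
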
......@@ -116,7 +116,7 @@ public:
Tensor wtscale;
Tensor wcscales;
cublasHandle_t handle;
hipblasHandle_t handle;
};
class GEMM_W8A8 : public Module {
......
......@@ -258,7 +258,7 @@ private:
waitEvent(eventLoadDone.get());
funcCompute(layer);
nextComputeDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(cudaEventRecord(nextComputeDone->event, getCurrentCUDAStream()));
checkCUDA(hipEventRecord(nextComputeDone->event, getCurrentHIPStreamMasqueradingAsCUDA()));
workaroundFlush();
}
......@@ -272,7 +272,7 @@ private:
funcLoad(layer + 1);
}
nextLoadDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(cudaEventRecord(nextLoadDone->event, getCurrentCUDAStream()));
checkCUDA(hipEventRecord(nextLoadDone->event, getCurrentHIPStreamMasqueradingAsCUDA()));
workaroundFlush();
}
......@@ -287,7 +287,7 @@ private:
if (!event) {
return;
}
checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), event->event));
checkCUDA(hipStreamWaitEvent(getCurrentHIPStreamMasqueradingAsCUDA(), event->event));
}
// WDDM prevents multiple streams run concurrently
......@@ -312,12 +312,12 @@ private:
if (!needWorkaround) {
return;
}
cudaStreamQuery(getCurrentCUDAStream());
hipStreamQuery(getCurrentHIPStreamMasqueradingAsCUDA());
}
void workaroundSynchronize() {
if (!needWorkaround) {
return;
}
checkCUDA(cudaEventSynchronize(eventComputeDone->event));
checkCUDA(hipEventSynchronize(eventComputeDone->event));
}
};
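
The event-based handshake between the load and compute paths above boils down to the following HIP pattern; this is a minimal sketch with hypothetical stream and function names, not the project's code:

#include <hip/hip_runtime.h>

// Record "load done" on the load stream, then make the compute stream wait on it
// before the next layer's kernels run.
void sync_load_then_compute(hipStream_t loadStream, hipStream_t computeStream) {
    hipEvent_t loadDone;
    hipEventCreateWithFlags(&loadDone, hipEventDisableTiming);

    // ... enqueue hipMemcpyAsync of the layer's weights on loadStream ...
    hipEventRecord(loadDone, loadStream);

    // Work submitted to computeStream after this call will not start on the GPU
    // until everything recorded before loadDone has completed.
    hipStreamWaitEvent(computeStream, loadDone, 0);
    // ... launch the layer's kernels on computeStream ...

    hipEventDestroy(loadDone); // destruction is deferred until the event is no longer in use
}
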
......@@ -5,7 +5,7 @@
#include "flash_api.h"
#include "kernels/misc_kernels.h"
#include <nvtx3/nvToolsExt.h>
#include <roctracer/roctx.h>
using spdlog::fmt_lib::format;
using namespace nunchaku;
......@@ -241,9 +241,9 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
bool pag,
bool cfg) {
nvtxRangePushA("SanaLinearTransformerBlock");
roctxRangePushA("SanaLinearTransformerBlock");
nvtxRangePushA("chunk");
roctxRangePushA("chunk");
// Tensor ones = Tensor::ones({hidden_size}, Tensor::FP16, x.device());
......@@ -262,10 +262,10 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = chunked;
// auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(timestep);
nvtxRangePop();
roctxRangePop();
{
nvtxRangePushA("LinearAttention");
roctxRangePushA("LinearAttention");
Tensor residual = hidden_states;
Tensor norm_hidden_states = norm1.forward(hidden_states);
......@@ -279,11 +279,11 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
hidden_states = attn_output;
nvtxRangePop();
roctxRangePop();
}
{
nvtxRangePushA("CrossAttention");
roctxRangePushA("CrossAttention");
debug("norm_hidden_states_cross", hidden_states);
Tensor attn_output = cross_attn.forward(hidden_states, encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt);
......@@ -293,11 +293,11 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
hidden_states = attn_output;
nvtxRangePop();
roctxRangePop();
}
{
nvtxRangePushA("Feed-forward");
roctxRangePushA("Feed-forward");
debug("hidden_states_ff", hidden_states);
Tensor norm_hidden_states = norm2.forward(hidden_states);
......@@ -311,10 +311,10 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
hidden_states = ff_output;
nvtxRangePop();
roctxRangePop();
}
nvtxRangePop();
roctxRangePop();
debug("hidden_states_out", hidden_states);
......
......@@ -121,15 +121,15 @@ SafeTensors::SafeTensors(const std::string &filename) {
auto methodPrivate = [&]() {
this->mapped = std::make_unique<MMapImplPrivate>(filename);
checkCUDA(
cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
hipHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), hipHostRegisterPortable));
this->hostRegistered = true;
this->memoryPinned = true;
};
auto methodMio = [&]() {
this->mapped = std::make_unique<MMapImplMio>(filename);
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()),
checkCUDA(hipHostRegister(const_cast<char *>(this->mapped->data()),
this->mapped->size(),
cudaHostRegisterPortable | cudaHostRegisterReadOnly));
hipHostRegisterPortable | hipHostRegisterReadOnly));
this->hostRegistered = true;
this->memoryPinned = true;
};
......@@ -183,8 +183,8 @@ SafeTensors::SafeTensors(const std::string &filename) {
SafeTensors::~SafeTensors() {
if (this->hostRegistered) {
if (cudaHostUnregister(const_cast<char *>(this->mapped->data())) != cudaSuccess) {
spdlog::warn("cudaHostUnregister failed: {}", cudaGetErrorString(cudaGetLastError()));
if (hipHostUnregister(const_cast<char *>(this->mapped->data())) != hipSuccess) {
spdlog::warn("hipHostUnregister failed: {}", hipGetErrorString(hipGetLastError()));
}
}
}
......
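
The hipHostRegister/hipHostUnregister pair above pins an already-mmapped file so async host-to-device copies can DMA from it; a minimal standalone sketch of that pattern with hypothetical helper names:

#include <hip/hip_runtime.h>
#include <cstdio>

// Pin an already-mapped host range so async host-to-device copies can DMA from it.
bool pin_host_range(void *ptr, size_t size) {
    hipError_t err = hipHostRegister(ptr, size, hipHostRegisterPortable);
    if (err != hipSuccess) {
        std::fprintf(stderr, "hipHostRegister failed: %s\n", hipGetErrorString(err));
        return false;
    }
    return true;
}

void unpin_host_range(void *ptr) {
    if (hipHostUnregister(ptr) != hipSuccess) {
        std::fprintf(stderr, "hipHostUnregister failed: %s\n", hipGetErrorString(hipGetLastError()));
    }
}
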
......@@ -9,17 +9,17 @@ public:
this->size = size;
this->device.type = Device::CPU;
this->ptr = ptr;
// auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
// if (ret == cudaSuccess) {
// auto ret = hipHostRegister(ptr, size, hipHostRegisterPortable | hipHostRegisterReadOnly);
// if (ret == hipSuccess) {
// this->registered = true;
// } else {
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
// cudaGetErrorString(cudaGetLastError()))); this->registered = false;
// log(std::format("hipHostRegister failed at {:p} (size={}): {}", ptr, size,
// hipGetErrorString(hipGetLastError()))); this->registered = false;
// }
}
virtual ~BufferMMap() {
// if (registered) {
// checkCUDA(cudaHostUnregister(ptr));
// checkCUDA(hipHostUnregister(ptr));
// }
}
......
#include "hip/hip_runtime.h"
#pragma once
#include "common.h"
......@@ -74,10 +75,10 @@ public:
BufferHost(size_t size) {
this->size = size;
this->device.type = Device::CPU;
checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
checkCUDA(hipHostMalloc(&this->ptr, size, hipHostMallocPortable));
}
virtual ~BufferHost() {
checkCUDA(cudaFreeHost(this->ptr));
checkCUDA(hipHostFree(this->ptr));
}
};
......@@ -86,20 +87,20 @@ public:
BufferCUDA(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
// checkCUDA(cudaGetDevice(&this->device.idx));
// checkCUDA(hipGetDevice(&this->device.idx));
this->device.idx = CUDADeviceContext::getDevice();
if (size == 0) {
this->ptr = nullptr;
}
// TODO: buffer used in multiple streams?
checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
checkCUDA(hipMallocAsync(&this->ptr, size, getCurrentHIPStreamMasqueradingAsCUDA()));
}
virtual ~BufferCUDA() {
if (this->size == 0) {
assert(!this->ptr);
return;
}
checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
checkCUDA(hipFreeAsync(this->ptr, getCurrentHIPStreamMasqueradingAsCUDA()));
}
virtual bool isAsyncBuffer() override {
return true;
......@@ -111,11 +112,11 @@ public:
BufferCUDASync(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
checkCUDA(cudaGetDevice(&this->device.idx));
checkCUDA(cudaMalloc(&this->ptr, size));
checkCUDA(hipGetDevice(&this->device.idx));
checkCUDA(hipMalloc(&this->ptr, size));
}
virtual ~BufferCUDASync() {
checkCUDA(cudaFree(this->ptr));
checkCUDA(hipFree(this->ptr));
}
};
......@@ -415,8 +416,8 @@ public:
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(cudaMemsetAsync(
data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
checkCUDA(hipMemsetAsync(
data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentHIPStreamMasqueradingAsCUDA()));
return *this;
}
Tensor &copy_(Tensor other) {
......@@ -444,13 +445,13 @@ public:
return *this;
}
lockBuffer(this->buffer, getCurrentCUDAStream());
lockBuffer(other.buffer, getCurrentCUDAStream());
checkCUDA(cudaMemcpyAsync(data_ptr<char>(),
lockBuffer(this->buffer, getCurrentHIPStreamMasqueradingAsCUDA());
lockBuffer(other.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
checkCUDA(hipMemcpyAsync(data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size(),
getCopyKind(this->device(), other.device()),
getCurrentCUDAStream()));
getCurrentHIPStreamMasqueradingAsCUDA()));
return *this;
}
......@@ -487,7 +488,7 @@ public:
} else if (device.type == Device::CUDA) {
CUDADeviceContext ctx(device.idx);
checkCUDA(
cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
hipMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentHIPStreamMasqueradingAsCUDA()));
}
}
......@@ -502,7 +503,7 @@ public:
static Tensor ones(TensorShape shape, ScalarType scalarType, Device device) {
Tensor result = allocate(shape, scalarType, device);
// FIXME FIXME FIXME
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
checkCUDA(hipMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentHIPStreamMasqueradingAsCUDA()));
return result;
}
static Tensor
......@@ -522,18 +523,18 @@ public:
Tensor result = allocate(this->shape.dataExtent, this->scalarType, device);
result.copy_(*this);
// lockBuffer(this->buffer, getCurrentCUDAStream());
// lockBuffer(result.buffer, getCurrentCUDAStream());
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
// getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// lockBuffer(this->buffer, getCurrentHIPStreamMasqueradingAsCUDA());
// lockBuffer(result.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), hipMemcpyDefault,
// getCurrentHIPStreamMasqueradingAsCUDA())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// hipMemcpyHostToDevice, getCurrentHIPStreamMasqueradingAsCUDA()));
// } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// hipMemcpyDeviceToHost, getCurrentHIPStreamMasqueradingAsCUDA()));
// } else {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDefault, getCurrentCUDAStream()));
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// hipMemcpyDefault, getCurrentHIPStreamMasqueradingAsCUDA()));
// }
return result;
}
......@@ -548,38 +549,38 @@ public:
// auto shapeOut = this->shape;
// shapeOut[dim] = upper_bound - lower_bound;
// assert(dst.shape.data == shapeOut.data);
// checkCUDA(cudaMemcpy2DAsync(
// checkCUDA(hipMemcpy2DAsync(
// dst.
// ));
// }
private:
static cudaMemcpyKind getCopyKind(Device dst, Device src) {
static hipMemcpyKind getCopyKind(Device dst, Device src) {
if (src.type == Device::CPU && dst.type == Device::CUDA) {
return cudaMemcpyHostToDevice;
return hipMemcpyHostToDevice;
}
if (src.type == Device::CUDA && dst.type == Device::CPU) {
return cudaMemcpyDeviceToHost;
return hipMemcpyDeviceToHost;
}
if (src.type == Device::CUDA && dst.type == Device::CUDA) {
return cudaMemcpyDeviceToDevice;
return hipMemcpyDeviceToDevice;
}
if (src.type == Device::CPU && dst.type == Device::CPU) {
return cudaMemcpyHostToHost;
return hipMemcpyHostToHost;
}
return cudaMemcpyDefault;
return hipMemcpyDefault;
}
// static bool isAsyncBuffer(Buffer *buffer) {
// return dynamic_cast<BufferCUDA *>(buffer);
// }
static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
static inline std::map<hipStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
// completes
static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
static void lockBuffer(std::shared_ptr<Buffer> buffer, hipStream_t stream) {
if (!buffer->isAsyncBuffer()) {
lockedBuffers[stream].insert(buffer);
}
......@@ -589,16 +590,16 @@ public:
static void unlockBuffers() {
lockedBuffers.clear();
}
static void unlockBuffers(cudaStream_t stream) {
static void unlockBuffers(hipStream_t stream) {
lockedBuffers[stream].clear();
}
static void synchronizeDevice() {
checkCUDA(cudaDeviceSynchronize());
checkCUDA(hipDeviceSynchronize());
unlockBuffers();
}
static void synchronizeStream(cudaStream_t stream) {
checkCUDA(cudaStreamSynchronize(stream));
static void synchronizeStream(hipStream_t stream) {
checkCUDA(hipStreamSynchronize(stream));
unlockBuffers(stream);
}
};
......
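
The getCopyKind mapping above reduces to a small helper; a minimal sketch of how it would drive hipMemcpyAsync outside the Tensor class, with hypothetical names:

#include <hip/hip_runtime.h>

// Same dst/src -> hipMemcpyKind mapping as getCopyKind above.
static hipMemcpyKind copy_kind(bool dst_on_gpu, bool src_on_gpu) {
    if (!src_on_gpu && dst_on_gpu) return hipMemcpyHostToDevice;
    if (src_on_gpu && !dst_on_gpu) return hipMemcpyDeviceToHost;
    if (src_on_gpu && dst_on_gpu)  return hipMemcpyDeviceToDevice;
    return hipMemcpyHostToHost;
}

void async_copy(void *dst, bool dst_on_gpu,
                const void *src, bool src_on_gpu,
                size_t bytes, hipStream_t stream) {
    hipMemcpyAsync(dst, src, bytes, copy_kind(dst_on_gpu, src_on_gpu), stream);
}
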
......@@ -19,47 +19,47 @@
#include <optional>
#include <chrono>
#include <functional>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <hip/hip_runtime_api.h>
#include <hipblas/hipblas.h>
#include <spdlog/spdlog.h>
class CUDAError : public std::runtime_error {
public:
CUDAError(cudaError_t errorCode, std::source_location location)
CUDAError(hipError_t errorCode, std::source_location location)
: std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}
public:
const cudaError_t errorCode;
const hipError_t errorCode;
const std::source_location location;
private:
static std::string format(cudaError_t errorCode, std::source_location location) {
static std::string format(hipError_t errorCode, std::source_location location) {
return spdlog::fmt_lib::format(
"CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
"CUDA error: {} (at {}:{})", hipGetErrorString(errorCode), location.file_name(), location.line());
}
};
inline cudaError_t checkCUDA(cudaError_t retValue,
inline hipError_t checkCUDA(hipError_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != cudaSuccess) {
(void)cudaGetLastError();
if (retValue != hipSuccess) {
(void)hipGetLastError();
throw CUDAError(retValue, location);
}
return retValue;
}
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
inline hipblasStatus_t checkCUBLAS(hipblasStatus_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != CUBLAS_STATUS_SUCCESS) {
if (retValue != HIPBLAS_STATUS_SUCCESS) {
throw std::runtime_error(spdlog::fmt_lib::format(
"CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
"CUBLAS error: {} (at {}:{})", rocblas_status_to_string(retValue), location.file_name(), location.line()));
}
return retValue;
}
inline thread_local std::stack<cudaStream_t> stackCUDAStreams;
inline thread_local std::stack<hipStream_t> stackCUDAStreams;
inline cudaStream_t getCurrentCUDAStream() {
inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA() {
if (stackCUDAStreams.empty()) {
return 0;
}
......@@ -67,9 +67,9 @@ inline cudaStream_t getCurrentCUDAStream() {
}
struct CUDAStreamContext {
cudaStream_t stream;
hipStream_t stream;
CUDAStreamContext(cudaStream_t stream) : stream(stream) {
CUDAStreamContext(hipStream_t stream) : stream(stream) {
stackCUDAStreams.push(stream);
}
CUDAStreamContext(const CUDAStreamContext &) = delete;
......@@ -82,30 +82,30 @@ struct CUDAStreamContext {
};
struct CUDAStreamWrapper {
cudaStream_t stream;
hipStream_t stream;
CUDAStreamWrapper() {
checkCUDA(cudaStreamCreate(&stream));
checkCUDA(hipStreamCreate(&stream));
}
CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
~CUDAStreamWrapper() {
checkCUDA(cudaStreamDestroy(stream));
checkCUDA(hipStreamDestroy(stream));
}
};
struct CUDAEventWrapper {
cudaEvent_t event;
hipEvent_t event;
CUDAEventWrapper(unsigned int flags = cudaEventDefault) {
checkCUDA(cudaEventCreateWithFlags(&event, flags));
CUDAEventWrapper(unsigned int flags = hipEventDefault) {
checkCUDA(hipEventCreateWithFlags(&event, flags));
}
CUDAEventWrapper(const CUDAEventWrapper &) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
~CUDAEventWrapper() {
checkCUDA(cudaEventDestroy(event));
checkCUDA(hipEventDestroy(event));
}
};
......@@ -162,7 +162,7 @@ public:
static int getDevice() {
int idx = -1;
if (cacheDisabled() || currentDeviceCache < 0) {
checkCUDA(cudaGetDevice(&idx));
checkCUDA(hipGetDevice(&idx));
} else {
idx = currentDeviceCache;
}
......@@ -177,7 +177,7 @@ private:
if (!cacheDisabled() && currentDeviceCache == idx) {
return;
}
checkCUDA(cudaSetDevice(idx));
checkCUDA(hipSetDevice(idx));
currentDeviceCache = cacheDisabled() ? -1 : idx;
}
......@@ -190,13 +190,13 @@ private:
}
};
inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local std::map<int, cudaDeviceProp> props;
inline hipDeviceProp_t *getCurrentDeviceProperties() {
static thread_local std::map<int, hipDeviceProp_t> props;
int deviceId = CUDADeviceContext::getDevice();
if (!props.contains(deviceId)) {
cudaDeviceProp prop;
checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
hipDeviceProp_t prop;
checkCUDA(hipGetDeviceProperties(&prop, deviceId));
props[deviceId] = prop;
}
return &props.at(deviceId);
......@@ -217,16 +217,16 @@ constexpr int log2Up(T value) {
}
struct CUBLASWrapper {
cublasHandle_t handle = nullptr;
hipblasHandle_t handle = nullptr;
CUBLASWrapper() {
checkCUBLAS(cublasCreate(&handle));
checkCUBLAS(hipblasCreate(&handle));
}
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(const CUBLASWrapper &&) = delete;
~CUBLASWrapper() {
if (handle) {
checkCUBLAS(cublasDestroy(handle));
checkCUBLAS(hipblasDestroy(handle));
}
}
};
......
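
Assuming the header above is the project's common.h and a ROCm version with hipMallocAsync is available, a minimal usage sketch of the renamed stream/error helpers (stream_scratch_example is a hypothetical caller):

#include "common.h"

void stream_scratch_example() {
    CUDAStreamWrapper stream;              // wraps hipStreamCreate / hipStreamDestroy
    CUDAStreamContext ctx(stream.stream);  // pushes it onto stackCUDAStreams

    void *buf = nullptr;
    checkCUDA(hipMallocAsync(&buf, 1 << 20, getCurrentHIPStreamMasqueradingAsCUDA()));
    // ... enqueue kernels that use buf on the same stream ...
    checkCUDA(hipFreeAsync(buf, getCurrentHIPStreamMasqueradingAsCUDA()));
    checkCUDA(hipStreamSynchronize(stream.stream));
}
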
#include "torch.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/hip/HIPContext.h>
using spdlog::fmt_lib::format;
......@@ -37,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
Tensor::lockBuffer(result.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
return result;
}
......@@ -76,10 +76,10 @@ at::Tensor to_torch(Tensor input) {
}
TorchOpContext::TorchOpContext() {
stackCUDAStreams.push(at::cuda::getCurrentCUDAStream().stream());
stackCUDAStreams.push(at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream());
}
TorchOpContext::~TorchOpContext() {
assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
assert(stackCUDAStreams.top() == at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream());
stackCUDAStreams.pop();
}
#include "hip/hip_runtime.h"
#include "activation_kernels_impl.cuh"
#include "activation_kernels.h"
#include "dispatch_utils.h"
......@@ -8,10 +9,10 @@
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = getCurrentCUDAStream(); \
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
hipLaunchKernelGGL(( vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>) \
, dim3(grid), dim3(block), 0, stream, out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
});
void silu_and_mul(Tensor &out, // [..., d]
......@@ -21,14 +22,14 @@ void silu_and_mul(Tensor &out, // [..., d]
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
vllm::silu_and_mul_kernel<scalar_t>
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
hipLaunchKernelGGL(( vllm::silu_and_mul_kernel<scalar_t>)
, dim3(grid), dim3(block), 0, stream, out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
}
......@@ -41,8 +42,8 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( vllm::dequant_silu_and_mul_quant_kernel<float, false>), dim3(grid), dim3(block), 0, stream,
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
}
......@@ -57,8 +58,8 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( vllm::dequant_silu_and_mul_quant_kernel<float *, true>), dim3(grid), dim3(block), 0, stream, out.data_ptr<int8_t>(),
input.data_ptr<int32_t>(),
d,
scale_gate,
......
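
The hipLaunchKernelGGL rewrites above are what hipify-perl emits for triple-chevron launches; a minimal self-contained sketch of the equivalent launch, with an illustrative kernel that is not part of the diff:

#include <hip/hip_runtime.h>

__global__ void scale_kernel(float *x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= s;
}

// hipify-perl turns `scale_kernel<<<grid, block, 0, stream>>>(x, s, n)` into the
// macro form below; under hipcc both spellings launch the same kernel.
void launch_scale(float *x, float s, int n, hipStream_t stream) {
    dim3 grid((n + 255) / 256);
    dim3 block(256);
    hipLaunchKernelGGL(scale_kernel, grid, block, 0, stream, x, s, n);
}
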
#include "hip/hip_runtime.h"
#include "utils.cuh"
#include "reduction_utils.cuh"
......
......@@ -11,7 +11,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutl
*/
#pragma once
#include <cuda_fp16.h>
#include <hip/hip_fp16.h>
#include <cstdint>
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
......@@ -75,14 +75,14 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__hip_bfloat162 const &source, uint4 *result) {
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->x)); *reinterpret_cast<__nv_bfloat162 *>(&result->y) =
// cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__nv_bfloat162
// *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *reinterpret_cast<__hip_bfloat162 *>(&result->x) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2
// *>(&result->x)); *reinterpret_cast<__hip_bfloat162 *>(&result->y) =
// cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__hip_bfloat162
// *>(&result->z) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__hip_bfloat162 *>(&result->w) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2
// *>(&result->w));
// return;
......
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "hip/hip_runtime.h"
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include "semaphore.h"
#include "gemm_awq.h"
// #include "../../../nunchaku/csrc/quantization/dequantize.cuh"
......@@ -46,8 +47,8 @@
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0, \
in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
template<int N>
......@@ -90,8 +91,8 @@ __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __hip_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __hip_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
......@@ -103,8 +104,8 @@ __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, u
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __hip_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __hip_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
......@@ -149,7 +150,7 @@ __device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp
template<>
__device__ __inline__ void
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
mma_m16n8k16<__hip_bfloat16>(float *C_warp, __hip_bfloat16 *A_shared_warp, __hip_bfloat16 *B_shared_warp) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
......@@ -378,7 +379,7 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
......@@ -944,7 +945,7 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
int M,
int N,
int K) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
......@@ -1277,12 +1278,12 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else if (_in_feats.scalar_type() == Tensor::BF16) {
using f16_t = __nv_bfloat16;
using f16_t = __hip_bfloat16;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
......@@ -1357,8 +1358,8 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else {
......
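
The launches above first raise the dynamic shared-memory limit with hipFuncSetAttribute before passing a large kSmemByteSize; a minimal sketch of that pattern with an illustrative kernel and size (names are not from the diff):

#include <hip/hip_runtime.h>

__global__ void uses_dynamic_smem(float *out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

void launch_with_big_smem(float *out, hipStream_t stream) {
    const int kSmemByteSize = 64 * 1024; // illustrative size above the default opt-in limit
    hipFuncSetAttribute(reinterpret_cast<const void *>(uses_dynamic_smem),
                        hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
    hipLaunchKernelGGL(uses_dynamic_smem, dim3(1), dim3(256), kSmemByteSize, stream, out);
}
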
#include "hip/hip_runtime.h"
/*
* Modified from NVIDIA
* [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
......@@ -30,8 +31,8 @@
#include "../utils.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <stdio.h>
#include "dequantize.cuh"
......@@ -80,7 +81,7 @@ __device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
}
template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
__device__ __forceinline__ packed_as<__hip_bfloat16, 2>::type half2half2<__hip_bfloat16>(__hip_bfloat16 x) {
return __bfloat162bfloat162(x);
}
......@@ -93,7 +94,7 @@ __device__ __forceinline__ float2 half22float2<half2>(half2 val) {
}
template<>
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
__device__ __forceinline__ float2 half22float2<__hip_bfloat162>(__hip_bfloat162 val) {
return __bfloat1622float2(val);
}
......@@ -106,8 +107,8 @@ __global__ void gemv_kernel(const half_t *inputs,
const int IC,
const int OC) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
if constexpr (std::is_same_v<half_t, __hip_bfloat16>) {
trap_unsupported_arch();
return;
}
......@@ -282,10 +283,10 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
<<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
hipLaunchKernelGGL(( gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>)
, dim3(num_blocks), dim3(num_threads), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
checkCUDA(cudaGetLastError());
checkCUDA(hipGetLastError());
}
});
......
#include "hip/hip_runtime.h"
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
......@@ -55,7 +56,7 @@ public:
/// Permit fetching the synchronization mechanism early
__device__ void fetch() {
if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
......@@ -82,7 +83,7 @@ public:
__syncthreads();
if (wait_thread) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
......
......@@ -2,13 +2,13 @@
#include "common.h"
#include "Tensor.h"
#include <cuda_fp16.h>
#include <hip/hip_fp16.h>
template<typename F>
inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__nv_bfloat16>();
return func.template operator()<__hip_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
case Tensor::FP32:
......@@ -23,7 +23,7 @@ template<typename F>
inline auto dispatchFloat16(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__nv_bfloat16>();
return func.template operator()<__hip_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
default:
......@@ -36,7 +36,7 @@ template<typename F>
inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__nv_bfloat16>();
return func.template operator()<__hip_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
case Tensor::FP32:
......
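
A hypothetical caller of the dispatch helpers above, relying on the C++20 templated-lambda syntax they expect; the header name and element_size function are assumptions for illustration:

#include "dispatch_utils.h" // assumed name of the header shown above
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>

// Dispatch on the runtime dtype and receive the matching element type as a
// template parameter.
size_t element_size(Tensor::ScalarType t) {
    return dispatchFloat16(t, []<typename scalar_t>() {
        return sizeof(scalar_t); // half or __hip_bfloat16, both 2 bytes
    });
}
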