Commit 7c282e2e authored by fengzch-das

Merge branch 'revert-1a8114bf' into 'v1.0.2'

Revert "hipify code"

See merge request !1
parents 1a8114bf 0a7c8614
Pipeline #3051 canceled
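For orientation: this revert swaps the hipified identifiers back to their CUDA originals one-for-one across the touched files. A representative before/after pair from the first hunk, plus a rough correspondence as it appears in this diff (illustrative, not exhaustive):

    // HIP spelling being removed:
    checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
    // CUDA spelling restored by this MR:
    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));

    // hipDeviceSetLimit / hipLimitStackSize     -> cudaDeviceSetLimit / cudaLimitStackSize
    // hipMallocAsync / hipFreeAsync             -> cudaMallocAsync / cudaFreeAsync
    // hipblasHgemm / HIPBLAS_OP_T               -> cublasHgemm / CUBLAS_OP_T
    // __hip_bfloat16 / __hip_bfloat162          -> __nv_bfloat16 / __nv_bfloat162
    // roctxRangePushA / roctxRangePop           -> nvtxRangePushA / nvtxRangePop
    // hipLaunchKernelGGL((kernel), grid, ...)   -> kernel<<<grid, block, smem, stream>>>(...)
    // getCurrentHIPStreamMasqueradingAsCUDA()   -> getCurrentCUDAStream()
    // (checkCUDA / checkCUBLAS keep their names; only the wrapped types change.)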
......@@ -12,8 +12,8 @@ public:
spdlog::info("Initializing QuantizedGEMM");
size_t val = 0;
checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W4A4>((int)in_features,
......@@ -42,7 +42,7 @@ public:
std::string dumpTensorBF16(Tensor x) {
std::stringstream ss;
for (int i = 0; i < 256; i++) {
ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__hip_bfloat16>()[i]));
ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
}
ss << std::endl;
return ss.str();
......
......@@ -12,8 +12,8 @@ public:
spdlog::info("Initializing QuantizedGEMM88");
size_t val = 0;
checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
net = std::make_unique<GEMM_W8A8>(
......
......@@ -8,27 +8,27 @@ namespace nunchaku::utils {
void set_cuda_stack_limit(int64_t newval) {
size_t val = 0;
checkCUDA(hipDeviceSetLimit(hipLimitStackSize, (size_t)newval));
checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
spdlog::debug("Stack={}", val);
}
void disable_memory_auto_release() {
int device;
checkCUDA(hipGetDevice(&device));
hipMemPool_t mempool;
checkCUDA(hipDeviceGetDefaultMemPool(&mempool, device));
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
uint64_t threshold = UINT64_MAX;
checkCUDA(hipMemPoolSetAttribute(mempool, hipMemPoolAttrReleaseThreshold, &threshold));
checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
}
void trim_memory() {
int device;
checkCUDA(hipGetDevice(&device));
hipMemPool_t mempool;
checkCUDA(hipDeviceGetDefaultMemPool(&mempool, device));
checkCUDA(cudaGetDevice(&device));
cudaMemPool_t mempool;
checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
size_t bytesToKeep = 0;
checkCUDA(hipMemPoolTrimTo(mempool, bytesToKeep));
checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
}
void set_faster_i2f_mode(std::string mode) {
......
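The utils hunk above restores the stream-ordered memory-pool tuning: raise the default pool's release threshold so free memory is never handed back to the OS automatically, and trim the pool explicitly when needed. A minimal self-contained sketch of that pattern (helper names are illustrative, error handling trimmed; requires CUDA 11.2+):

    #include <cuda_runtime_api.h>
    #include <cstdint>

    void keep_pool_memory() {                 // mirrors disable_memory_auto_release()
        int device = 0;
        cudaGetDevice(&device);
        cudaMemPool_t pool = nullptr;
        cudaDeviceGetDefaultMemPool(&pool, device);
        uint64_t threshold = UINT64_MAX;      // never release free blocks back to the OS
        cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
    }

    void release_pool_memory() {              // mirrors trim_memory()
        int device = 0;
        cudaGetDevice(&device);
        cudaMemPool_t pool = nullptr;
        cudaDeviceGetDefaultMemPool(&pool, device);
        cudaMemPoolTrimTo(pool, 0);           // return all currently unused pool memory
    }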
......@@ -5,7 +5,7 @@
#include "kernels/awq/gemv_awq.h"
#include "kernels/dwconv.h"
#include <nvtx3/roctracer/roctx.h>
#include <nvtx3/nvToolsExt.h>
using namespace nunchaku;
......@@ -117,7 +117,7 @@ GEMM_W4A4::GEMM_W4A4(
wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);
#if NO_LORA_FUSION
checkCUBLAS(hipblasCreate(&handle));
checkCUBLAS(cublasCreate(&handle));
#endif
}
......@@ -140,7 +140,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
} else if (key == "wtscale") {
assert(src.numel() == 1);
if (src.dtype() == Tensor::BF16) {
*dst.data_ptr<float>() = float(*src.data_ptr<__hip_bfloat16>());
*dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
} else if (src.dtype() == Tensor::FP16) {
*dst.data_ptr<float>() = float(*src.data_ptr<half>());
} else if (src.dtype() == Tensor::FP32) {
......@@ -242,15 +242,15 @@ void GEMM_W4A4::forward(Tensor x,
qact.is_unsigned,
this->lora_scales);
roctxRangePushA("LoraUp");
nvtxRangePushA("LoraUp");
static const half one = 1.0;
static const half zero = 0.0;
// lora_up: [M, R] * [OC, R] => [M, OC]
// cublas view: [OC, R] * [M, R]^T
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_T,
HIPBLAS_OP_N,
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
this->out_features,
M,
this->lora_rank,
......@@ -263,7 +263,7 @@ void GEMM_W4A4::forward(Tensor x,
out.data_ptr<half>(),
this->out_features));
roctxRangePop();
nvtxRangePop();
#endif
}
......@@ -380,7 +380,7 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
qact.is_unsigned,
this->lora_scales);
roctxRangePushA("LoraUp");
nvtxRangePushA("LoraUp");
static const half one = 1.0;
static const half zero = 0.0;
......@@ -388,9 +388,9 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
// lora_up: [M, R] * [OC, R]^T => [M, OC]
// cublas view: [R, OC]^T * [R, M] => [OC, M]
// lora_up layout wrong?
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_T,
HIPBLAS_OP_N,
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
this->out_features,
M,
this->lora_rank,
......@@ -403,16 +403,16 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
out.data_ptr<half>(),
this->out_features));
roctxRangePop();
nvtxRangePop();
if (fuse == FuseOptions::GELU_QUANT) {
roctxRangePushA("LoraDown");
nvtxRangePushA("LoraDown");
// IC is for next lora (OC of this layer)
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M] => [R, M]
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
this->lora_rank,
M,
this->out_features,
......@@ -427,7 +427,7 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
out = {};
roctxRangePop();
nvtxRangePop();
}
#endif
......@@ -473,13 +473,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
static const half one = 1.0;
static const half zero = 0.0;
roctxRangePushA("LoraDown");
nvtxRangePushA("LoraDown");
// lora_down: [M, IC] * [IC, R] => [M, R]
// cublas view: [R, IC] * [IC, M]
checkCUBLAS(hipblasHgemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
checkCUBLAS(cublasHgemm(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
this->lora_rank,
M,
this->in_features,
......@@ -492,7 +492,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
qact.lora_act.data_ptr<half>(),
this->lora_rank));
roctxRangePop();
nvtxRangePop();
kernels::quantize_w4a4_act(x, qact.act, qact.ascales);
......
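The LoRA GEMMs above use the usual row-major-through-column-major cuBLAS convention spelled out in the "cublas view" comments. A standalone sketch of that call shape, assuming the leading dimensions follow the layouts in those comments (handle and pointer names are illustrative):

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Row-major out[M, OC] = act[M, R] * lora_up[OC, R]^T, expressed to cuBLAS
    // (column-major) as out^T[OC, M] = lora_up^T * act^T.
    void lora_up_gemm(cublasHandle_t handle, const half *act, const half *lora_up,
                      half *out, int M, int OC, int R) {
        const half one = 1.0f, zero = 0.0f;   // host pointer mode is the cuBLAS default
        cublasHgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                    OC, M, R,
                    &one,
                    lora_up, R,               // [OC, R] row-major == [R, OC] col-major
                    act, R,                   // [M, R]  row-major == [R, M]  col-major
                    &zero,
                    out, OC);                 // [OC, M] col-major == [M, OC] row-major
    }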
......@@ -116,7 +116,7 @@ public:
Tensor wtscale;
Tensor wcscales;
hipblasHandle_t handle;
cublasHandle_t handle;
};
class GEMM_W8A8 : public Module {
......
......@@ -258,7 +258,7 @@ private:
waitEvent(eventLoadDone.get());
funcCompute(layer);
nextComputeDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(hipEventRecord(nextComputeDone->event, getCurrentHIPStreamMasqueradingAsCUDA()));
checkCUDA(cudaEventRecord(nextComputeDone->event, getCurrentCUDAStream()));
workaroundFlush();
}
......@@ -272,7 +272,7 @@ private:
funcLoad(layer + 1);
}
nextLoadDone = std::make_unique<CUDAEventWrapper>();
checkCUDA(hipEventRecord(nextLoadDone->event, getCurrentHIPStreamMasqueradingAsCUDA()));
checkCUDA(cudaEventRecord(nextLoadDone->event, getCurrentCUDAStream()));
workaroundFlush();
}
......@@ -287,7 +287,7 @@ private:
if (!event) {
return;
}
checkCUDA(hipStreamWaitEvent(getCurrentHIPStreamMasqueradingAsCUDA(), event->event));
checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), event->event));
}
// WDDM prevents multiple streams run concurrently
......@@ -312,12 +312,12 @@ private:
if (!needWorkaround) {
return;
}
hipStreamQuery(getCurrentHIPStreamMasqueradingAsCUDA());
cudaStreamQuery(getCurrentCUDAStream());
}
void workaroundSynchronize() {
if (!needWorkaround) {
return;
}
checkCUDA(hipEventSynchronize(eventComputeDone->event));
checkCUDA(cudaEventSynchronize(eventComputeDone->event));
}
};
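The offloading code above restores the plain CUDA event handshake between the load and compute streams: record an event after work is enqueued on one stream, and make the other stream wait on it before proceeding. A minimal sketch of that ordering (stream and event names are illustrative):

    #include <cuda_runtime_api.h>

    void order_load_before_compute(cudaStream_t loadStream, cudaStream_t computeStream) {
        cudaEvent_t loadDone;
        cudaEventCreateWithFlags(&loadDone, cudaEventDisableTiming);

        // ... enqueue weight copies for the next layer on loadStream ...
        cudaEventRecord(loadDone, loadStream);

        // Kernels on computeStream will not start until the copies above finish.
        cudaStreamWaitEvent(computeStream, loadDone, 0);
        // ... enqueue the layer's kernels on computeStream ...

        cudaEventDestroy(loadDone);           // safe: destruction is deferred until the event completes
    }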
......@@ -5,7 +5,7 @@
#include "flash_api.h"
#include "kernels/misc_kernels.h"
#include <nvtx3/roctracer/roctx.h>
#include <nvtx3/nvToolsExt.h>
using spdlog::fmt_lib::format;
using namespace nunchaku;
......@@ -241,9 +241,9 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
bool pag,
bool cfg) {
roctxRangePushA("SanaLinearTransformerBlock");
nvtxRangePushA("SanaLinearTransformerBlock");
roctxRangePushA("chunk");
nvtxRangePushA("chunk");
// Tensor ones = Tensor::ones({hidden_size}, Tensor::FP16, x.device());
......@@ -262,10 +262,10 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = chunked;
// auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(timestep);
roctxRangePop();
nvtxRangePop();
{
roctxRangePushA("LinearAttention");
nvtxRangePushA("LinearAttention");
Tensor residual = hidden_states;
Tensor norm_hidden_states = norm1.forward(hidden_states);
......@@ -279,11 +279,11 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
hidden_states = attn_output;
roctxRangePop();
nvtxRangePop();
}
{
roctxRangePushA("CrossAttention");
nvtxRangePushA("CrossAttention");
debug("norm_hidden_states_cross", hidden_states);
Tensor attn_output = cross_attn.forward(hidden_states, encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt);
......@@ -293,11 +293,11 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
hidden_states = attn_output;
roctxRangePop();
nvtxRangePop();
}
{
roctxRangePushA("Feed-forward");
nvtxRangePushA("Feed-forward");
debug("hidden_states_ff", hidden_states);
Tensor norm_hidden_states = norm2.forward(hidden_states);
......@@ -311,10 +311,10 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
hidden_states = ff_output;
roctxRangePop();
nvtxRangePop();
}
roctxRangePop();
nvtxRangePop();
debug("hidden_states_out", hidden_states);
......
......@@ -121,15 +121,15 @@ SafeTensors::SafeTensors(const std::string &filename) {
auto methodPrivate = [&]() {
this->mapped = std::make_unique<MMapImplPrivate>(filename);
checkCUDA(
hipHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), hipHostRegisterPortable));
cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
this->hostRegistered = true;
this->memoryPinned = true;
};
auto methodMio = [&]() {
this->mapped = std::make_unique<MMapImplMio>(filename);
checkCUDA(hipHostRegister(const_cast<char *>(this->mapped->data()),
checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()),
this->mapped->size(),
hipHostRegisterPortable | hipHostRegisterReadOnly));
cudaHostRegisterPortable | cudaHostRegisterReadOnly));
this->hostRegistered = true;
this->memoryPinned = true;
};
......@@ -183,8 +183,8 @@ SafeTensors::SafeTensors(const std::string &filename) {
SafeTensors::~SafeTensors() {
if (this->hostRegistered) {
if (hipHostUnregister(const_cast<char *>(this->mapped->data())) != hipSuccess) {
spdlog::warn("hipHostUnregister failed: {}", hipGetErrorString(hipGetLastError()));
if (cudaHostUnregister(const_cast<char *>(this->mapped->data())) != cudaSuccess) {
spdlog::warn("cudaHostUnregister failed: {}", cudaGetErrorString(cudaGetLastError()));
}
}
}
......
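The SafeTensors changes above go back to pinning the memory-mapped weight file with cudaHostRegister so uploads can run as true async host-to-device copies. A minimal sketch of that idea, assuming a region already mapped with mmap/mio as in the code above (cudaHostRegisterReadOnly needs CUDA 11.1+; function names are illustrative):

    #include <cuda_runtime_api.h>
    #include <cstddef>

    bool pin_mapped_region(void *data, size_t size) {
        cudaError_t err = cudaHostRegister(
            data, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
        return err == cudaSuccess;            // on failure, fall back to pageable copies
    }

    void unpin_mapped_region(void *data) {
        if (cudaHostUnregister(data) != cudaSuccess) {
            (void)cudaGetLastError();         // clear the sticky error, as the destructor above does
        }
    }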
......@@ -9,17 +9,17 @@ public:
this->size = size;
this->device.type = Device::CPU;
this->ptr = ptr;
// auto ret = hipHostRegister(ptr, size, hipHostRegisterPortable | hipHostRegisterReadOnly);
// if (ret == hipSuccess) {
// auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
// if (ret == cudaSuccess) {
// this->registered = true;
// } else {
// log(std::format("hipHostRegister failed at {:p} (size={}): {}", ptr, size,
// hipGetErrorString(hipGetLastError()))); this->registered = false;
// log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
// cudaGetErrorString(cudaGetLastError()))); this->registered = false;
// }
}
virtual ~BufferMMap() {
// if (registered) {
// checkCUDA(hipHostUnregister(ptr));
// checkCUDA(cudaHostUnregister(ptr));
// }
}
......
#include "hip/hip_runtime.h"
#pragma once
#include "common.h"
......@@ -75,10 +74,10 @@ public:
BufferHost(size_t size) {
this->size = size;
this->device.type = Device::CPU;
checkCUDA(hipHostMalloc(&this->ptr, size, hipHostMallocPortable));
checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
}
virtual ~BufferHost() {
checkCUDA(hipHostFree(this->ptr));
checkCUDA(cudaFreeHost(this->ptr));
}
};
......@@ -87,20 +86,20 @@ public:
BufferCUDA(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
// checkCUDA(hipGetDevice(&this->device.idx));
// checkCUDA(cudaGetDevice(&this->device.idx));
this->device.idx = CUDADeviceContext::getDevice();
if (size == 0) {
this->ptr = nullptr;
}
// TODO: buffer used in multiple streams?
checkCUDA(hipMallocAsync(&this->ptr, size, getCurrentHIPStreamMasqueradingAsCUDA()));
checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
}
virtual ~BufferCUDA() {
if (this->size == 0) {
assert(!this->ptr);
return;
}
checkCUDA(hipFreeAsync(this->ptr, getCurrentHIPStreamMasqueradingAsCUDA()));
checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
}
virtual bool isAsyncBuffer() override {
return true;
......@@ -112,11 +111,11 @@ public:
BufferCUDASync(size_t size) {
this->size = size;
this->device.type = Device::CUDA;
checkCUDA(hipGetDevice(&this->device.idx));
checkCUDA(hipMalloc(&this->ptr, size));
checkCUDA(cudaGetDevice(&this->device.idx));
checkCUDA(cudaMalloc(&this->ptr, size));
}
virtual ~BufferCUDASync() {
checkCUDA(hipFree(this->ptr));
checkCUDA(cudaFree(this->ptr));
}
};
......@@ -416,8 +415,8 @@ public:
Tensor &zero_() {
assert(this->is_contiguous());
checkCUDA(hipMemsetAsync(
data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentHIPStreamMasqueradingAsCUDA()));
checkCUDA(cudaMemsetAsync(
data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
return *this;
}
Tensor &copy_(Tensor other) {
......@@ -445,13 +444,13 @@ public:
return *this;
}
lockBuffer(this->buffer, getCurrentHIPStreamMasqueradingAsCUDA());
lockBuffer(other.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
checkCUDA(hipMemcpyAsync(data_ptr<char>(),
lockBuffer(this->buffer, getCurrentCUDAStream());
lockBuffer(other.buffer, getCurrentCUDAStream());
checkCUDA(cudaMemcpyAsync(data_ptr<char>(),
other.data_ptr<char>(),
shape.size() * scalar_size(),
getCopyKind(this->device(), other.device()),
getCurrentHIPStreamMasqueradingAsCUDA()));
getCurrentCUDAStream()));
return *this;
}
......@@ -488,7 +487,7 @@ public:
} else if (device.type == Device::CUDA) {
CUDADeviceContext ctx(device.idx);
checkCUDA(
hipMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentHIPStreamMasqueradingAsCUDA()));
cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
}
}
......@@ -503,7 +502,7 @@ public:
static Tensor ones(TensorShape shape, ScalarType scalarType, Device device) {
Tensor result = allocate(shape, scalarType, device);
// FIXME FIXME FIXME
checkCUDA(hipMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentHIPStreamMasqueradingAsCUDA()));
checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
return result;
}
static Tensor
......@@ -523,18 +522,18 @@ public:
Tensor result = allocate(this->shape.dataExtent, this->scalarType, device);
result.copy_(*this);
// lockBuffer(this->buffer, getCurrentHIPStreamMasqueradingAsCUDA());
// lockBuffer(result.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), hipMemcpyDefault,
// getCurrentHIPStreamMasqueradingAsCUDA())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// hipMemcpyHostToDevice, getCurrentHIPStreamMasqueradingAsCUDA()));
// lockBuffer(this->buffer, getCurrentCUDAStream());
// lockBuffer(result.buffer, getCurrentCUDAStream());
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
// getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyHostToDevice, getCurrentCUDAStream()));
// } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// hipMemcpyDeviceToHost, getCurrentHIPStreamMasqueradingAsCUDA()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
// } else {
// checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// hipMemcpyDefault, getCurrentHIPStreamMasqueradingAsCUDA()));
// checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
// cudaMemcpyDefault, getCurrentCUDAStream()));
// }
return result;
}
......@@ -549,38 +548,38 @@ public:
// auto shapeOut = this->shape;
// shapeOut[dim] = upper_bound - lower_bound;
// assert(dst.shape.data == shapeOut.data);
// checkCUDA(hipMemcpy2DAsync(
// checkCUDA(cudaMemcpy2DAsync(
// dst.
// ));
// }
private:
static hipMemcpyKind getCopyKind(Device dst, Device src) {
static cudaMemcpyKind getCopyKind(Device dst, Device src) {
if (src.type == Device::CPU && dst.type == Device::CUDA) {
return hipMemcpyHostToDevice;
return cudaMemcpyHostToDevice;
}
if (src.type == Device::CUDA && dst.type == Device::CPU) {
return hipMemcpyDeviceToHost;
return cudaMemcpyDeviceToHost;
}
if (src.type == Device::CUDA && dst.type == Device::CUDA) {
return hipMemcpyDeviceToDevice;
return cudaMemcpyDeviceToDevice;
}
if (src.type == Device::CPU && dst.type == Device::CPU) {
return hipMemcpyHostToHost;
return cudaMemcpyHostToHost;
}
return hipMemcpyDefault;
return cudaMemcpyDefault;
}
// static bool isAsyncBuffer(Buffer *buffer) {
// return dynamic_cast<BufferCUDA *>(buffer);
// }
static inline std::map<hipStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
public:
// before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
// completes
static void lockBuffer(std::shared_ptr<Buffer> buffer, hipStream_t stream) {
static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
if (!buffer->isAsyncBuffer()) {
lockedBuffers[stream].insert(buffer);
}
......@@ -590,16 +589,16 @@ public:
static void unlockBuffers() {
lockedBuffers.clear();
}
static void unlockBuffers(hipStream_t stream) {
static void unlockBuffers(cudaStream_t stream) {
lockedBuffers[stream].clear();
}
static void synchronizeDevice() {
checkCUDA(hipDeviceSynchronize());
checkCUDA(cudaDeviceSynchronize());
unlockBuffers();
}
static void synchronizeStream(hipStream_t stream) {
checkCUDA(hipStreamSynchronize(stream));
static void synchronizeStream(cudaStream_t stream) {
checkCUDA(cudaStreamSynchronize(stream));
unlockBuffers(stream);
}
};
......
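The Tensor changes above keep the stream-keyed lockBuffer bookkeeping: a shared_ptr to any non-stream-ordered buffer is parked in a per-stream set before an async copy and released only after that stream is synchronized, so the allocation cannot be freed while the GPU may still touch it. A stripped-down sketch of the idea (Buffer is a placeholder for the project's buffer class):

    #include <cuda_runtime_api.h>
    #include <map>
    #include <memory>
    #include <set>

    struct Buffer { void *ptr = nullptr; };   // placeholder type

    std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;

    void lock(std::shared_ptr<Buffer> buf, cudaStream_t stream) {
        lockedBuffers[stream].insert(std::move(buf));   // the ref-count keeps it alive
    }

    void sync_and_unlock(cudaStream_t stream) {
        cudaStreamSynchronize(stream);        // everything enqueued so far has finished
        lockedBuffers[stream].clear();        // buffers may now be destroyed safely
    }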
......@@ -19,47 +19,47 @@
#include <optional>
#include <chrono>
#include <functional>
#include <hip/hip_runtime_api.h>
#include <hipblas/hipblas.h>
#include <cuda_runtime_api.h>
#include <cublas_v2.h>
#include <spdlog/spdlog.h>
class CUDAError : public std::runtime_error {
public:
CUDAError(hipError_t errorCode, std::source_location location)
CUDAError(cudaError_t errorCode, std::source_location location)
: std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}
public:
const hipError_t errorCode;
const cudaError_t errorCode;
const std::source_location location;
private:
static std::string format(hipError_t errorCode, std::source_location location) {
static std::string format(cudaError_t errorCode, std::source_location location) {
return spdlog::fmt_lib::format(
"CUDA error: {} (at {}:{})", hipGetErrorString(errorCode), location.file_name(), location.line());
"CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
}
};
inline hipError_t checkCUDA(hipError_t retValue,
inline cudaError_t checkCUDA(cudaError_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != hipSuccess) {
(void)hipGetLastError();
if (retValue != cudaSuccess) {
(void)cudaGetLastError();
throw CUDAError(retValue, location);
}
return retValue;
}
inline hipblasStatus_t checkCUBLAS(hipblasStatus_t retValue,
inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
const std::source_location location = std::source_location::current()) {
if (retValue != HIPBLAS_STATUS_SUCCESS) {
if (retValue != CUBLAS_STATUS_SUCCESS) {
throw std::runtime_error(spdlog::fmt_lib::format(
"CUBLAS error: {} (at {}:{})", rocblas_status_to_string(retValue), location.file_name(), location.line()));
"CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
}
return retValue;
}
inline thread_local std::stack<hipStream_t> stackCUDAStreams;
inline thread_local std::stack<cudaStream_t> stackCUDAStreams;
inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA() {
inline cudaStream_t getCurrentCUDAStream() {
if (stackCUDAStreams.empty()) {
return 0;
}
......@@ -67,9 +67,9 @@ inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA() {
}
struct CUDAStreamContext {
hipStream_t stream;
cudaStream_t stream;
CUDAStreamContext(hipStream_t stream) : stream(stream) {
CUDAStreamContext(cudaStream_t stream) : stream(stream) {
stackCUDAStreams.push(stream);
}
CUDAStreamContext(const CUDAStreamContext &) = delete;
......@@ -82,30 +82,30 @@ struct CUDAStreamContext {
};
struct CUDAStreamWrapper {
hipStream_t stream;
cudaStream_t stream;
CUDAStreamWrapper() {
checkCUDA(hipStreamCreate(&stream));
checkCUDA(cudaStreamCreate(&stream));
}
CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
CUDAStreamWrapper(CUDAStreamWrapper &&) = delete;
~CUDAStreamWrapper() {
checkCUDA(hipStreamDestroy(stream));
checkCUDA(cudaStreamDestroy(stream));
}
};
struct CUDAEventWrapper {
hipEvent_t event;
cudaEvent_t event;
CUDAEventWrapper(unsigned int flags = hipEventDefault) {
checkCUDA(hipEventCreateWithFlags(&event, flags));
CUDAEventWrapper(unsigned int flags = cudaEventDefault) {
checkCUDA(cudaEventCreateWithFlags(&event, flags));
}
CUDAEventWrapper(const CUDAEventWrapper &) = delete;
CUDAEventWrapper(CUDAEventWrapper &&) = delete;
~CUDAEventWrapper() {
checkCUDA(hipEventDestroy(event));
checkCUDA(cudaEventDestroy(event));
}
};
......@@ -162,7 +162,7 @@ public:
static int getDevice() {
int idx = -1;
if (cacheDisabled() || currentDeviceCache < 0) {
checkCUDA(hipGetDevice(&idx));
checkCUDA(cudaGetDevice(&idx));
} else {
idx = currentDeviceCache;
}
......@@ -177,7 +177,7 @@ private:
if (!cacheDisabled() && currentDeviceCache == idx) {
return;
}
checkCUDA(hipSetDevice(idx));
checkCUDA(cudaSetDevice(idx));
currentDeviceCache = cacheDisabled() ? -1 : idx;
}
......@@ -190,13 +190,13 @@ private:
}
};
inline hipDeviceProp_t *getCurrentDeviceProperties() {
static thread_local std::map<int, hipDeviceProp_t> props;
inline cudaDeviceProp *getCurrentDeviceProperties() {
static thread_local std::map<int, cudaDeviceProp> props;
int deviceId = CUDADeviceContext::getDevice();
if (!props.contains(deviceId)) {
hipDeviceProp_t prop;
checkCUDA(hipGetDeviceProperties(&prop, deviceId));
cudaDeviceProp prop;
checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
props[deviceId] = prop;
}
return &props.at(deviceId);
......@@ -217,16 +217,16 @@ constexpr int log2Up(T value) {
}
struct CUBLASWrapper {
hipblasHandle_t handle = nullptr;
cublasHandle_t handle = nullptr;
CUBLASWrapper() {
checkCUBLAS(hipblasCreate(&handle));
checkCUBLAS(cublasCreate(&handle));
}
CUBLASWrapper(CUBLASWrapper &&) = delete;
CUBLASWrapper(const CUBLASWrapper &&) = delete;
~CUBLASWrapper() {
if (handle) {
checkCUBLAS(hipblasDestroy(handle));
checkCUBLAS(cublasDestroy(handle));
}
}
};
......
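A short usage sketch of the helpers restored in the header above, assuming it is the project's common.h as the includes elsewhere in this diff suggest (the buffer size is arbitrary):

    #include "common.h"

    void example() {
        CUDAStreamWrapper stream;             // RAII: cudaStreamCreate / cudaStreamDestroy
        CUDAStreamContext ctx(stream.stream); // pushes onto stackCUDAStreams
        void *buf = nullptr;
        checkCUDA(cudaMallocAsync(&buf, 1 << 20, getCurrentCUDAStream()));
        checkCUDA(cudaFreeAsync(buf, getCurrentCUDAStream()));
        checkCUDA(cudaStreamSynchronize(getCurrentCUDAStream()));
    }   // ctx pops the stream, then the stream itself is destroyed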
#include "torch.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/cuda/CUDAContext.h>
using spdlog::fmt_lib::format;
......@@ -37,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
result.scalarType = mapType.at(input.scalar_type());
result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));
Tensor::lockBuffer(result.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
return result;
}
......@@ -76,10 +76,10 @@ at::Tensor to_torch(Tensor input) {
}
TorchOpContext::TorchOpContext() {
stackCUDAStreams.push(at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream());
stackCUDAStreams.push(at::cuda::getCurrentCUDAStream().stream());
}
TorchOpContext::~TorchOpContext() {
assert(stackCUDAStreams.top() == at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream());
assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
stackCUDAStreams.pop();
}
#include "hip/hip_runtime.h"
#include "activation_kernels_impl.cuh"
#include "activation_kernels.h"
#include "dispatch_utils.h"
......@@ -9,10 +8,10 @@
int num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA(); \
const cudaStream_t stream = getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
hipLaunchKernelGGL(( vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>) \
, dim3(grid), dim3(block), 0, stream, out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
});
void silu_and_mul(Tensor &out, // [..., d]
......@@ -22,14 +21,14 @@ void silu_and_mul(Tensor &out, // [..., d]
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
const cudaStream_t stream = getCurrentCUDAStream();
// dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
// vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
// out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
// });
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
hipLaunchKernelGGL(( vllm::silu_and_mul_kernel<scalar_t>)
, dim3(grid), dim3(block), 0, stream, out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
vllm::silu_and_mul_kernel<scalar_t>
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
}
......@@ -42,8 +41,8 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( vllm::dequant_silu_and_mul_quant_kernel<float, false>), dim3(grid), dim3(block), 0, stream,
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
}
......@@ -58,8 +57,8 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
int d = input.size(-1) / 2;
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
hipLaunchKernelGGL(( vllm::dequant_silu_and_mul_quant_kernel<float *, true>), dim3(grid), dim3(block), 0, stream, out.data_ptr<int8_t>(),
const cudaStream_t stream = getCurrentCUDAStream();
vllm::dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(out.data_ptr<int8_t>(),
input.data_ptr<int32_t>(),
d,
scale_gate,
......
#include "hip/hip_runtime.h"
#include "utils.cuh"
#include "reduction_utils.cuh"
......
......@@ -11,7 +11,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutl
*/
#pragma once
#include <hip/hip_fp16.h>
#include <cuda_fp16.h>
#include <cstdint>
__forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {
......@@ -75,14 +75,14 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
}
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__hip_bfloat162 const &source, uint4 *result) {
__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
// dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
// *reinterpret_cast<__hip_bfloat162 *>(&result->x) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2
// *>(&result->x)); *reinterpret_cast<__hip_bfloat162 *>(&result->y) =
// cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__hip_bfloat162
// *>(&result->z) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__hip_bfloat162 *>(&result->w) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2
// *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->x)); *reinterpret_cast<__nv_bfloat162 *>(&result->y) =
// cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__nv_bfloat162
// *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
// *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
// *>(&result->w));
// return;
......
#include "hip/hip_runtime.h"
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "semaphore.h"
#include "gemm_awq.h"
// #include "../../../nunchaku/csrc/quantization/dequantize.cuh"
......@@ -47,8 +46,8 @@
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>; \
hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0, \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>( \
in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
template<int N>
......@@ -91,8 +90,8 @@ __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __hip_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __hip_bfloat16 types.");
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
......@@ -104,8 +103,8 @@ __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, u
template<typename f16_t>
__inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __hip_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __hip_bfloat16 types.");
static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
"ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),
......@@ -150,7 +149,7 @@ __device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp
template<>
__device__ __inline__ void
mma_m16n8k16<__hip_bfloat16>(float *C_warp, __hip_bfloat16 *A_shared_warp, __hip_bfloat16 *B_shared_warp) {
mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
......@@ -379,7 +378,7 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
int M,
int N,
int K) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
......@@ -945,7 +944,7 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
int M,
int N,
int K) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
trap_unsupported_arch();
return;
#endif
......@@ -1278,12 +1277,12 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else if (_in_feats.scalar_type() == Tensor::BF16) {
using f16_t = __hip_bfloat16;
using f16_t = __nv_bfloat16;
auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
auto kernel = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());
......@@ -1358,8 +1357,8 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
hipLaunchKernelGGL(( kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
}
} else {
......
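The AWQ GEMM launches above restore the CUDA idiom of raising a kernel's dynamic shared-memory limit before launching it with a large shared-memory carveout. A self-contained sketch of that pattern (kernel, sizes, and launch shape are illustrative, not the ones in this file; compiled with nvcc):

    #include <cuda_runtime.h>

    __global__ void smem_heavy_kernel(float *out) {
        extern __shared__ float smem[];       // dynamic shared memory
        smem[threadIdx.x] = static_cast<float>(threadIdx.x);
        __syncthreads();
        out[threadIdx.x] = smem[threadIdx.x];
    }

    void launch(float *out, cudaStream_t stream) {
        const int kSmemBytes = 96 * 1024;     // > 48 KB needs an explicit opt-in (arch permitting)
        cudaFuncSetAttribute(smem_heavy_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes);
        smem_heavy_kernel<<<1, 256, kSmemBytes, stream>>>(out);
    }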
#include "hip/hip_runtime.h"
/*
* Modified from NVIDIA
* [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
......@@ -31,8 +30,8 @@
#include "../utils.cuh"
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdio.h>
#include "dequantize.cuh"
......@@ -81,7 +80,7 @@ __device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
}
template<>
__device__ __forceinline__ packed_as<__hip_bfloat16, 2>::type half2half2<__hip_bfloat16>(__hip_bfloat16 x) {
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
......@@ -94,7 +93,7 @@ __device__ __forceinline__ float2 half22float2<half2>(half2 val) {
}
template<>
__device__ __forceinline__ float2 half22float2<__hip_bfloat162>(__hip_bfloat162 val) {
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
return __bfloat1622float2(val);
}
......@@ -107,8 +106,8 @@ __global__ void gemv_kernel(const half_t *inputs,
const int IC,
const int OC) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
if constexpr (std::is_same_v<half_t, __hip_bfloat16>) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
trap_unsupported_arch();
return;
}
......@@ -283,10 +282,10 @@ Tensor gemv_awq(
return;
}
if constexpr (M > 0) {
hipLaunchKernelGGL(( gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>)
, dim3(num_blocks), dim3(num_threads), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
<<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
checkCUDA(hipGetLastError());
checkCUDA(cudaGetLastError());
}
});
......
#include "hip/hip_runtime.h"
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
......@@ -56,7 +55,7 @@ public:
/// Permit fetching the synchronization mechanism early
__device__ void fetch() {
if (wait_thread) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
......@@ -83,7 +82,7 @@ public:
__syncthreads();
if (wait_thread) {
#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 700
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
......
......@@ -2,13 +2,13 @@
#include "common.h"
#include "Tensor.h"
#include <hip/hip_fp16.h>
#include <cuda_fp16.h>
template<typename F>
inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__hip_bfloat16>();
return func.template operator()<__nv_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
case Tensor::FP32:
......@@ -23,7 +23,7 @@ template<typename F>
inline auto dispatchFloat16(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__hip_bfloat16>();
return func.template operator()<__nv_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
default:
......@@ -36,7 +36,7 @@ template<typename F>
inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
switch (scalarType) {
case Tensor::BF16:
return func.template operator()<__hip_bfloat16>();
return func.template operator()<__nv_bfloat16>();
case Tensor::FP16:
return func.template operator()<half>();
case Tensor::FP32:
......
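A usage sketch of the dispatch helpers above: the caller passes a C++20 generic lambda and the helper instantiates it for the dtype held by the tensor, mirroring the commented-out example in the activation kernels earlier in this diff (the kernel named in the comment is illustrative):

    #include "dispatch_utils.h"   // per the include at the top of activation_kernels
    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    void run_for_dtype(Tensor out, Tensor in) {
        (void)out;                // referenced only inside the real launch below
        dispatchFloat16(in.scalar_type(), [&]<typename scalar_t>() {
            // scalar_t is half or __nv_bfloat16 here; launch the matching instantiation, e.g.
            // my_kernel<scalar_t><<<grid, block, 0, getCurrentCUDAStream()>>>(
            //     out.data_ptr<scalar_t>(), in.data_ptr<scalar_t>());
        });
    }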