Commit ac6a172d authored by zhanggzh
change versio code

parent 60fb332b
r"""
The fmoe package contains MoE Layers only.
"""
__version__ = "1.1.0"
from .version import __dcu_version__
from .layers import FMoE
from .linear import FMoELinear
from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
__version__ = '1.2.0'
__version__ = '1.1.0'
__dcu_version__ = '24041'
try:
    import torch
except ImportError:
    pass

import subprocess
from pathlib import Path
import os

UNKNOWN = "Unknown"


def sha_value(moe_root):
    try:
        return (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=moe_root)
            .decode("ascii")
            .strip()
        )
    except Exception:
        return UNKNOWN


def abi_value():
    try:
        return (
            subprocess.check_output("echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI", shell=True)
            .decode('ascii')
            .strip()[-1]
        )
    except Exception:
        return UNKNOWN


def dtk_version_value():
    try:
        dtk_path = os.getenv('ROCM_PATH')
        dtk_version_path = os.path.join(dtk_path, '.info', "version-dev")
        with open(dtk_version_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        dtk_version = "dtk" + lines[0][:-2].replace(".", "")
        return dtk_version
    except Exception:
        return UNKNOWN


def torch_version_value():
    try:
        torch_version = "torch" + torch.__version__.split('.')[0] + "." + torch.__version__.split('.')[1]
        return torch_version
    except Exception:
        return UNKNOWN


def moe_whl_name():
    try:
        moe_root = Path(__file__).parent
        sha = "das1.1.git" + sha_value(moe_root)[0:7]
        abi = "abi" + abi_value()
        dtk_version = dtk_version_value()
        try:
            import torch
            torch_version = torch_version_value()
        except ImportError:
            torch_version = "null"
        whl_name = "+" + sha + "." + abi + "." + dtk_version + "." + torch_version
        return whl_name
    except Exception:
        return UNKNOWN


def dcu_version():
    try:
        moe_version = '1.1.0'
        dcu_version = moe_version + moe_whl_name()
        return dcu_version
    except Exception:
        return UNKNOWN
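# Illustrative note (not executed anywhere in the package): dcu_version()
# glues the base package version onto the wheel-name suffix built above.
# Assuming a hypothetical commit with short SHA "abcdef0", a CXX11 ABI flag
# of 1, the DTK build "dtk24041" recorded above and torch 2.1, the resulting
# string would look like
#
#     1.1.0+das1.1.gitabcdef0.abi1.dtk24041.torch2.1
#
# Any component that cannot be detected degrades to the string "Unknown".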
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "../hip/stream_manager.h"
#include "../hip/utils/fmoe_utils.h"
#include <hip/hip_runtime.h>
__global__
void limit_by_capacity_kernel(const long* ec, int* cap, long* eca,
const long n_expert, const long n_worker) {
int eid = blockIdx.y;
int wid = blockIdx.x * blockDim.x + threadIdx.x;
if (wid < n_worker) {
int proposal = ec[wid * n_expert + eid];
int cap_left = atomicSub(cap + eid, proposal);
if (cap_left >= proposal) {
eca[wid * n_expert + eid] = proposal;
} else if (cap_left >= 0) {
eca[wid * n_expert + eid] = cap_left;
} else {
eca[wid * n_expert + eid] = 0;
}
}
}
void fmoe_cuda_limit_by_capacity_impl(const long* ec, int* cap,
long* eca, const long n_expert, const long n_worker,
CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(n_worker, 1024), n_expert);
dim3 block_dim(1024);
hipLaunchKernelGGL(( limit_by_capacity_kernel), dim3(grid_dim), dim3(block_dim), 0, smgr->torchStream(),
ec, cap, eca, n_expert, n_worker);
}
__global__
void prune_gate_by_capacity_kernel(const long* gate_idx, long* new_gate_idx,
int* ec,
const long batch_size, const long n_expert, const long n_worker) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < batch_size) {
int orig_cap = atomicSub(ec + gate_idx[i], 1);
if (orig_cap <= 0) {
new_gate_idx[i] = -1;
} else {
new_gate_idx[i] = gate_idx[i];
}
}
}
void fmoe_cuda_prune_gate_by_capacity_impl(long* gate_idx, long* new_gate_idx,
int* ec,
const long batch_size, const long n_expert, const long n_worker,
CudaStreamManager* smgr) {
dim3 grid_dim(CEIL(batch_size, 1024));
dim3 block_dim(1024);
hipLaunchKernelGGL(( prune_gate_by_capacity_kernel), dim3(grid_dim), dim3(block_dim), 0, smgr->torchStream(),
gate_idx, new_gate_idx, ec, batch_size, n_expert, n_worker
);
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <cstdio>
#include "../hip/balancing.cuh"
#include "../hip/global_exchange.h"
#include <torch/extension.h>
/*
* note that due to limit of cuda atomic operator, capacity should be int32
*/
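/*
 * Calling-convention sketch (an assumption inferred from the signatures
 * below, not copied from the Python caller): expert_count is an int64 tensor
 * of length n_worker * n_expert, while capacity must be an int32 tensor of
 * length n_expert because the kernel decrements it with atomicSub(int*, int).
 * A caller would therefore convert it beforehand, e.g. with something like
 * capacity.to(torch.int32).
 */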
torch::Tensor _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_worker) {
CHECK_INPUT(expert_count);
CHECK_INPUT(capacity);
auto expert_count_ack = torch::empty_like(expert_count);
auto smgr = getCudaStreamManager(expert_count.device().index());
fmoe_cuda_limit_by_capacity_impl(
expert_count.data_ptr<long>(),
capacity.data_ptr<int>(),
expert_count_ack.data_ptr<long>(),
n_expert, n_worker, smgr);
return expert_count_ack;
}
torch::Tensor _prune_gate_by_capacity(
torch::Tensor gate_idx, torch::Tensor expert_count,
long n_expert, long n_worker) {
auto smgr = getCudaStreamManager(expert_count.device().index());
auto batch_size = gate_idx.numel();
auto opt = torch::TensorOptions()
.dtype(gate_idx.dtype())
.device(gate_idx.device());
auto new_gate_idx = torch::empty(gate_idx.sizes(), opt);
fmoe_cuda_prune_gate_by_capacity_impl(
gate_idx.data_ptr<long>(),
new_gate_idx.data_ptr<long>(),
expert_count.data_ptr<int>(),
batch_size, n_expert, n_worker, smgr);
return new_gate_idx;
}
template<class T>
T* _cudamalloc(size_t sz) {
T* dptr;
hipMalloc(&dptr, sz * sizeof(T));
return dptr;
}
template<class T>
T* _h2d(const T* hptr, T* dptr, size_t sz) {
hipMemcpy(dptr, hptr, sz * sizeof(T), hipMemcpyHostToDevice);
return dptr;
}
template<class T>
T* _h2d(T* hptr, size_t sz) {
T* dptr = _cudamalloc<T>(sz);
return _h2d(hptr, dptr, sz);
}
template<class T>
T* _d2h(const T* dptr, T* hptr, size_t sz) {
hipMemcpy(hptr, dptr, sz * sizeof(T), hipMemcpyDeviceToHost);
return hptr;
}
template<class T>
T* _d2h(const T* dptr, size_t sz) {
T* hptr = new T[sz];
return _d2h(dptr, hptr, sz);
}
#ifdef FMOE_USE_NCCL
#include <rccl/rccl.h>
#define UPDATE_COUNTERS(__count__) { \
if (i == rank) { \
lec[j] += (__count__); \
} \
if (j == rank) { \
gec[i] += (__count__); \
cap -= (__count__); \
} \
}
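/*
 * Descriptive note (interpretation of the macro above): in the re-assignment
 * loop of _swipe_once, i is the worker whose dropped tokens are being
 * re-routed and j is the worker with spare capacity that absorbs them.
 * UPDATE_COUNTERS moves __count__ tokens from i to j: if this rank is the
 * source (i == rank) it will send that many extra tokens to j (lec[j]);
 * if it is the new destination (j == rank) it will receive them from i
 * (gec[i]) and its remaining capacity cap shrinks by the same amount.
 */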
std::vector<torch::Tensor> _swipe_once(
torch::Tensor gate_idx, torch::Tensor capacity,
long n_expert, long n_worker, long bias) {
auto device_idx = gate_idx.device().index();
auto smgr = getCudaStreamManager(device_idx);
int rank;
ncclCommUserRank(smgr->ncclcomm, &rank);
hipSetDevice(device_idx);
auto capacity_new = capacity.clone();
auto cap = capacity_new.item<long>();
long batch_size = gate_idx.size(0);
auto gate_idx_cpu = gate_idx.cpu();
long* gidx = gate_idx_cpu.data_ptr<long>();
/* Local count and exchange */
long *lec = new long[n_worker];
memset(lec, 0, n_worker * sizeof(long));
for (long i = 0; i < batch_size; ++i) {
++lec[gidx[i] / n_expert];
}
long *d_lec = _h2d(lec, n_worker), *d_gec = _cudamalloc<long>(n_worker);
fmoe_cuda_expert_exchange_impl(d_lec, d_gec, 1, n_worker, smgr);
smgr->syncTorch();
long *gec = _d2h(d_gec, n_worker);
/* Limit number of incoming samples */
long *drop_count = new long[n_worker];
memset(drop_count, 0, n_worker * sizeof(long));
for (long i = 0; i < n_worker; ++i) {
if (cap >= gec[i]) {
drop_count[i] = 0;
cap -= gec[i];
} else {
drop_count[i] = gec[i] - cap;
gec[i] = cap;
cap = 0;
}
}
/* Send limit information back */
_h2d(gec, d_gec, n_worker);
fmoe_cuda_expert_exchange_impl(d_gec, d_lec, 1, n_worker, smgr);
smgr->syncTorch();
_d2h(d_lec, lec, n_worker);
auto d_dropcount = _h2d(drop_count, n_worker);
ncclAllReduce(d_dropcount, d_dropcount, n_worker, ncclInt64, ncclSum,
smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
_d2h(d_dropcount, drop_count, n_worker);
auto d_gcap = _cudamalloc<long>(n_worker);
_h2d(&cap, d_gcap + rank, 1);
ncclAllGather(d_gcap + rank, d_gcap, 1, ncclInt64,
smgr->ncclcomm, smgr->torchStream());
smgr->syncTorch();
auto gcap = _d2h(d_gcap, n_worker);
/* Re-assign and update counters */
for (long i = 0, j = 0; i < n_worker; ++i) {
while (drop_count[i] > 0) {
if (drop_count[i] > gcap[j]) {
drop_count[i] -= gcap[j];
UPDATE_COUNTERS(gcap[j]);
++j;
} else {
gcap[j] -= drop_count[i];
UPDATE_COUNTERS(drop_count[i]);
break;
}
}
}
for (long i = 0; i < batch_size; ++i) {
auto widx = gidx[i] / n_expert;
if (lec[widx] > 0) {
--lec[widx];
} else {
gidx[i] = -1;
}
}
for (long i = 0, k = 0; i < batch_size; ++i) {
if (gidx[i] != -1) {
continue;
}
for (; lec[k] == 0; ++k);
--lec[k];
gidx[i] = k * n_expert + bias;
}
*capacity_new.data_ptr<long>() = cap;
delete [] drop_count;
delete [] lec;
delete [] gec;
delete [] gcap;
hipFree(d_dropcount);
hipFree(d_lec);
hipFree(d_gec);
hipFree(d_gcap);
return {gate_idx_cpu, capacity_new};
}
#undef UPDATE_COUNTERS
#endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifdef FMOE_USE_NCCL
#include <cstdlib>
#include <vector>
#include <torch/extension.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include "../../hip/fastermoe/smart_schedule.h"
#include "../../hip/fastermoe/status.h"
long pipeline_gran = -1;
int smart_sch_enabled = 0;
int isSmartSchEnabled() {
return smart_sch_enabled;
}
void setSmartSchEnabled(int s) {
smart_sch_enabled = s;
}
inline ncclDataType_t getNcclDataType(at::ScalarType t) {
switch (t) {
case at::kChar: return ncclInt8;
case at::kByte: return ncclUint8;
case at::kFloat: return ncclFloat;
case at::kDouble: return ncclDouble;
case at::kInt: return ncclInt32;
case at::kLong: return ncclInt64;
case at::kHalf: return ncclHalf;
case at::kBool: return ncclUint8;
#if defined(ENABLE_NCCL_BF16_DATATYPE)
case at::kBFloat16: return ncclBfloat16;
#endif
default: return ncclChar;
}
}
void _reduce_grad(
torch::Tensor t,
long root,
long expert_size) {
auto smgr = getCudaStreamManager(t.device().index());
hipEvent_t evt_stash;
hipEventCreate(&evt_stash);
hipEventRecord(evt_stash, smgr->torchStream());
FMOE_SWE(smgr->stream(0), evt_stash);
hipEventDestroy(evt_stash);
auto dtype = getNcclDataType(t.scalar_type());
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
t.scalar_type(), "fmoe_cuda_reduce_grad", ([&] {
void* buf = (void*)t.data_ptr<scalar_t>();
NCCL_SAFE_CALL(ncclReduce(buf, buf, expert_size,
dtype,
ncclSum, root,
smgr->ncclcomm, smgr->stream(0)));
})
);
}
std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long global_batch_size,
long expert_size,
long n_workers,
py::function forward_fn,
py::function get_param_fn,
py::function stash_fn,
py::function pop_fn) {
if (pipeline_gran == -1) {
char* p = getenv("FMOE_FASTER_GROUP_SIZE");
if (p) {
pipeline_gran = atoi(p);
} else {
pipeline_gran = 4;
}
setSmartSchEnabled(1);
}
auto smgr = getCudaStreamManager(input_buf.device().index());
int rank;
NCCL_SAFE_CALL(ncclCommUserRank(smgr->ncclcomm, &rank));
const auto num_expert = local_expert_count.size(0) / n_workers;
const auto d_model = input_buf.size(1);
// TODO: maybe empty is faster
auto global_input_buf = input_buf.new_zeros({global_batch_size, d_model});
auto global_output_buf = input_buf.new_zeros({global_batch_size, d_model});
auto output_buf = input_buf.new_zeros({input_buf.size(0), d_model});
std::vector<torch::Tensor> params;
auto stored_models_ = stored_models.data_ptr<bool>();
for (long i = 0; i < num_expert * n_workers; ++i) {
if (stored_models_[i]) {
torch::Tensor t = input_buf.new_empty({expert_size});
if (i / num_expert == rank) {
get_param_fn(t, i % num_expert);
}
params.push_back(t);
}
}
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
input_buf.scalar_type(), "fmoe_cuda_smart_sch_forward", ([&] {
fmoe_cuda_fused_forward_impl(
forward_fn,
stash_fn,
pop_fn,
input_buf.device(),
params,
input_buf.data_ptr<scalar_t>(),
global_input_buf.data_ptr<scalar_t>(),
global_output_buf.data_ptr<scalar_t>(),
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, num_expert, rank, n_workers, expert_size,
pipeline_gran, smgr);
}));
return {output_buf, global_input_buf};
}
torch::Tensor _smart_sch_backward(
torch::Tensor grad_out,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long buf_batch_size,
long global_batch_size,
long n_workers,
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn) {
const auto num_expert = local_expert_count.size(0) / n_workers;
auto smgr = getCudaStreamManager(grad_out.device().index());
int rank;
ncclCommUserRank(smgr->ncclcomm, &rank);
const auto d_model = grad_out.size(1);
auto global_grad_out = grad_out.new_zeros({global_batch_size, d_model});
auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
"fmoe_cuda_smartsch_backward", ([&] {
fmoe_cuda_fused_backward_impl(
backward_fn,
stash_fn,
pop_fn,
collect_fn,
set_grad_fn,
grad_out.device(),
grad_out.data_ptr<scalar_t>(),
global_grad_out.data_ptr<scalar_t>(),
global_grad_in.data_ptr<scalar_t>(),
grad_in.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
stored_models.data_ptr<bool>(),
d_model, num_expert, rank, n_workers,
pipeline_gran, smgr);
}));
return grad_in;
}
#endif
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef SMART_SCHEDULE_H
#define SMART_SCHEDULE_H
#include <cstdio>
#include <iostream>
#include <vector>
#include <hip/hip_runtime.h>
#include <rccl/rccl.h>
#include "../../hip/stream_manager.h"
#if defined(DTK_VERSION) && (DTK_VERSION < 110010)
#define FMOE_SWE(__s__,__e__) hipStreamWaitEvent(__s__,__e__,0)
#else
#define FMOE_SWE(__s__,__e__) hipStreamWaitEvent(__s__,__e__)
#endif
template<typename scalar_t>
void exchangeWith(
const scalar_t* sendbuf, size_t sendcount, int t_send,
scalar_t* recvbuf, size_t recvcount, int t_recv,
long d_model,
hipStream_t stream, ncclComm_t comm) {
if (sendcount) {
ncclSend(sendbuf, sendcount * d_model * sizeof(scalar_t),
ncclChar, t_send , comm, stream);
}
if (recvcount) {
ncclRecv(recvbuf, recvcount * d_model * sizeof(scalar_t),
ncclChar, t_recv, comm, stream);
}
}
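/*
 * Note: exchangeWith scales the element counts by d_model * sizeof(scalar_t)
 * and transfers the payload as ncclChar, so the same routine serves every
 * floating-point type dispatched by its callers without needing a matching
 * ncclDataType_t.
 */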
#define GEN_BASE(_step) \
long to_base = (group_rank + _step) % n_groups * pipeline_gran; \
long from_base = (group_rank + n_groups - _step) % n_groups * pipeline_gran;
#define GEN_IDX \
int idx_send = ei + rank_send * num_expert; \
int idx_recv = ei + rank_recv * num_expert; \
int gidx_send = ei * world_size + rank_send; \
int gidx_recv = ei * world_size + rank_recv; \
int idx_self = ei + rank * num_expert;
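/*
 * Descriptive note on the two helper macros above: GEN_BASE selects, for
 * pipeline step _step, the first rank of the group this rank sends to
 * (to_base) and of the group it receives from (from_base), where each group
 * is pipeline_gran consecutive ranks. GEN_IDX then derives the worker-major
 * indices into local_expert_count / global_expert_count (idx_send, idx_recv,
 * idx_self) and the expert-major indices into global_ptr (gidx_send,
 * gidx_recv) for expert ei and the peer ranks chosen in the inner loop.
 */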
void computePtrs(long num_expert, long rank, long world_size,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
int *local_ptr,
int *global_ptr,
int *local_global_ptr) {
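/*
 * Descriptive note (interpretation of the loop below): all three arrays are
 * exclusive prefix sums over num_expert * world_size counters. local_ptr
 * offsets the locally held tokens per (worker, expert) slot; global_ptr,
 * laid out expert-major, offsets the tokens received from other workers,
 * with a slot forced to zero when this rank's own copy of that expert is
 * shadowed elsewhere, since no tokens are received for it; local_global_ptr
 * accumulates tokens that stay local because the remote expert has been
 * fetched onto this rank.
 */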
local_ptr[0] = global_ptr[0] = local_global_ptr[0] = 0;
for (int i = 0; i < num_expert * world_size; ++i) {
local_ptr[i + 1] = local_ptr[i] + local_expert_count[i];
local_global_ptr[i + 1] = local_global_ptr[i];
// if model fetched, add local tokens
if (stored_models[i]){
local_global_ptr[i + 1] += local_expert_count[i];
}
auto expert_idx = i % num_expert;
auto worker_idx = i / num_expert;
auto gp_idx = expert_idx * world_size + worker_idx;
// if local model wasn't fetched, receive global tokens
if (stored_models[rank * num_expert + expert_idx]) {
global_ptr[gp_idx + 1] = 0;
} else {
global_ptr[gp_idx + 1] = global_expert_count[i];
}
}
global_ptr[0] = 0;
for (int i = 0; i < num_expert * world_size; ++i) {
global_ptr[i + 1] += global_ptr[i];
}
}
template<typename scalar_t>
void computeFn(py::function fn, c10::Device device,
scalar_t* inp_buf, scalar_t* out_buf,
long expert_idx, long store_idx, long offset, long micro_batch_size, long d_model,
CudaStreamManager* smgr) {
if(micro_batch_size == 0) {
return;
}
auto options = torch::TensorOptions()
.dtype(c10::CppTypeToScalarType<scalar_t>::value)
.device(device)
.requires_grad(true);
auto inp = torch::from_blob(inp_buf + offset * d_model,
{micro_batch_size, d_model}, options);
auto oup = torch::from_blob(out_buf + offset * d_model,
{micro_batch_size, d_model}, options);
smgr->use_default = true;
fn(inp, oup, expert_idx, store_idx);
smgr->use_default = false;
}
template<typename scalar_t>
void fmoe_cuda_fused_forward_impl(
py::function forward_fn,
py::function stash_fn,
py::function pop_fn,
c10::Device device,
std::vector<torch::Tensor> params,
scalar_t* input_buf,
scalar_t* global_input_buf,
scalar_t* global_output_buf,
scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
long d_model,
long num_expert, long rank, long world_size, long expert_size,
long pipeline_gran, CudaStreamManager* smgr) {
smgr->syncTorch();
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
computePtrs(num_expert, rank, world_size,
local_expert_count, global_expert_count, stored_models,
local_ptr, global_ptr, local_global_ptr);
if (pipeline_gran > world_size) {
pipeline_gran = world_size;
}
long n_groups = world_size / pipeline_gran;
long group_rank = rank / pipeline_gran;
hipEvent_t *input_ready = new hipEvent_t[n_groups];
hipEvent_t *output_ready = new hipEvent_t[n_groups];
hipEvent_t *output_torch_ready = new hipEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
hipEventCreate(input_ready + i);
hipEventCreate(output_ready + i);
hipEventCreate(output_torch_ready + i);
}
// S_0 ... S_n
for (long step = 0; step < n_groups; ++step) {
for (long ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + to_base;
int rank_recv = j + from_base;
GEN_IDX;
exchangeWith(input_buf + local_ptr[idx_send] * d_model,
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_input_buf + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
hipEventRecord(input_ready[step], smgr->stream(num_expert));
}
// Broadcast shadowed experts
hipEvent_t evt_get, *evt_shadow;
if (params.size() > 0) {
evt_shadow = new hipEvent_t[params.size()];
}
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
hipEventCreate(&evt_get);
hipEventRecord(evt_get, smgr->stream(0));
FMOE_SWE(smgr->stream(num_expert), evt_get);
hipEventDestroy(evt_get);
}
NCCL_SAFE_CALL(ncclBcast((void*)params[si].data_ptr<scalar_t>(),
expert_size * sizeof(scalar_t), ncclChar,
i / num_expert, smgr->ncclcomm, smgr->stream(num_expert)));
hipEventCreate(evt_shadow + si);
hipEventRecord(evt_shadow[si], smgr->stream(num_expert));
++si;
}
}
// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(0), input_ready[step]);
FMOE_SWE(smgr->torchStream(), input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
long micro_batch_size = global_ptr[ei * world_size +
(from_base + pipeline_gran)] - offset;
computeFn(forward_fn, device,
global_input_buf, global_output_buf,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
hipEventRecord(output_ready[step], smgr->stream(0));
hipEventRecord(output_torch_ready[step], smgr->torchStream());
}
// Compute over shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
FMOE_SWE(smgr->stream(0), evt_shadow[si]);
FMOE_SWE(smgr->torchStream(), evt_shadow[si]);
stash_fn(params[si], si, 0); // always put shadowed expert at first, so expert_idx = 0
long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i];
computeFn(forward_fn, device,
input_buf, output_buf,
0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
++si;
}
}
pop_fn(0);
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + from_base;
int rank_recv = j + to_base;
GEN_IDX;
exchangeWith(global_output_buf + global_ptr[gidx_send] * d_model,
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
output_buf + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
smgr->sync(num_expert + 1);
delete [] local_ptr;
delete [] global_ptr;
delete [] local_global_ptr;
checkCudaErrors(hipGetLastError());
for (long i = 0; i < n_groups; ++i) {
hipEventDestroy(input_ready[i]);
hipEventDestroy(output_ready[i]);
hipEventDestroy(output_torch_ready[i]);
}
for (unsigned i = 0; i < params.size(); ++i) {
hipEventDestroy(evt_shadow[i]);
}
delete [] input_ready;
delete [] output_ready;
delete [] output_torch_ready;
}
template<typename scalar_t>
void fmoe_cuda_fused_backward_impl(
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn,
c10::Device device,
scalar_t* grad_out,
scalar_t* global_grad_out,
scalar_t* global_grad_in,
scalar_t* grad_in,
const long* local_expert_count,
const long* global_expert_count,
const bool* stored_models,
long d_model,
long num_expert, long rank, long world_size,
long pipeline_gran, CudaStreamManager* smgr) {
smgr->syncTorch();
int *local_ptr = new int[num_expert * world_size + 1];
int *global_ptr = new int[num_expert * world_size + 1];
int *local_global_ptr = new int[num_expert * world_size + 1]; // local fetched models tracker
computePtrs(num_expert, rank, world_size,
local_expert_count, global_expert_count, stored_models,
local_ptr, global_ptr, local_global_ptr);
if (pipeline_gran > world_size) {
pipeline_gran = world_size;
}
long n_groups = world_size / pipeline_gran;
long group_rank = rank / pipeline_gran;
hipEvent_t *input_ready = new hipEvent_t[n_groups];
hipEvent_t *output_ready = new hipEvent_t[n_groups];
hipEvent_t *output_torch_ready = new hipEvent_t[n_groups];
for (long i = 0; i < n_groups; ++i) {
hipEventCreate(input_ready + i);
hipEventCreate(output_ready + i);
hipEventCreate(output_torch_ready + i);
}
// S_0 ... S_n
for (long step = 0; step < n_groups; ++step) {
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + to_base;
int rank_recv = j + from_base;
GEN_IDX;
exchangeWith(grad_out + local_ptr[idx_send] * d_model,
local_expert_count[idx_send] * !stored_models[idx_send], rank_send,
global_grad_out + global_ptr[gidx_recv] * d_model,
global_expert_count[idx_recv] * !stored_models[idx_self], rank_recv,
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
hipEventRecord(input_ready[step], smgr->stream(num_expert));
}
// Shadowed experts backward and reduce
hipEvent_t *evt_reduce = new hipEvent_t[num_expert];
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
stash_fn(si, 0);
long offset = local_ptr[i];
long micro_batch_size = local_expert_count[i];
computeFn(backward_fn, device,
grad_out, grad_in,
0, n_groups * num_expert + si, offset, micro_batch_size, d_model, smgr);
collect_fn(si, i / num_expert, 0);
if (i / num_expert == rank) {
hipEventCreate(evt_reduce + i % num_expert);
hipEventRecord(evt_reduce[i % num_expert], smgr->stream(0));
}
++si;
}
}
pop_fn(0);
// C_0 ... C_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(0), input_ready[step]);
FMOE_SWE(smgr->torchStream(), input_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
long offset = global_ptr[ei * world_size + from_base];
long micro_batch_size = global_ptr[ei * world_size +
(from_base + pipeline_gran)] - offset;
computeFn(backward_fn, device,
global_grad_out, global_grad_in,
(long) ei, step * num_expert + ei, offset, micro_batch_size, d_model, smgr);
}
hipEventRecord(output_ready[step], smgr->stream(0));
hipEventRecord(output_torch_ready[step], smgr->torchStream());
}
// Collect gradients for shadowed experts
for (long i = 0, si = 0; i < world_size * num_expert; ++i) {
if (stored_models[i]) {
if (i / num_expert == rank) {
FMOE_SWE(smgr->torchStream(), evt_reduce[i % num_expert]);
set_grad_fn(si, i % num_expert);
}
++si;
}
}
// R_0 ... R_n
for (long step = 0; step < n_groups; ++step) {
FMOE_SWE(smgr->stream(num_expert), output_ready[step]);
FMOE_SWE(smgr->stream(num_expert), output_torch_ready[step]);
for (int ei = 0; ei < num_expert; ++ei) {
GEN_BASE(step);
NCCL_SAFE_CALL(ncclGroupStart());
for (int j = 0; j < pipeline_gran; ++j) {
int rank_send = j + from_base;
int rank_recv = j + to_base;
GEN_IDX;
exchangeWith(global_grad_in + global_ptr[gidx_send] * d_model,
global_expert_count[idx_send] * !stored_models[idx_self], rank_send,
grad_in + local_ptr[idx_recv] * d_model,
local_expert_count[idx_recv] * !stored_models[idx_recv], rank_recv,
d_model, smgr->stream(num_expert), smgr->ncclcomm);
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
}
smgr->sync(num_expert + 1);
checkCudaErrors(hipGetLastError());
delete [] local_ptr;
delete [] global_ptr;
delete [] local_global_ptr;
checkCudaErrors(hipGetLastError());
for (long i = 0; i < n_groups; ++i) {
hipEventDestroy(input_ready[i]);
hipEventDestroy(output_ready[i]);
hipEventDestroy(output_torch_ready[i]);
}
delete [] input_ready;
delete [] output_ready;
delete [] output_torch_ready;
for (long i = 0; i < num_expert; ++i) {
if (stored_models[i + rank * num_expert]) {
hipEventDestroy(evt_reduce[i]);
}
}
delete [] evt_reduce;
}
#endif // SMART_SCHEDULE_H
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#pragma once
#ifndef FASTER_STATUS_H
#define FASTER_STATUS_H
int isSmartSchEnabled();
void setSmartSchEnabled(int);
#endif // FASTER_STATUS_H
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <iostream>
#include <vector>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/extension.h>
// global_exchange
#ifdef FMOE_USE_NCCL
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
#include <c10d/ProcessGroupNCCL.hpp>
#endif
torch::Tensor _expert_exchange(
torch::Tensor local_expert_count,
long n_expert, long n_workers);
torch::Tensor _global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
torch::Tensor _global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers);
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t);
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // TORCH_VERSION
#endif // FMOE_USE_NCCL
// local_exchange
void _assign_pos(
torch::Tensor cum_count,
torch::Tensor gate,
torch::Tensor pos);
void _expert_count(
torch::Tensor gate_idx,
torch::Tensor expert_count);
// parallel_linear
torch::Tensor _linear_forward(
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
);
std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
);
// balancing
torch::Tensor _limit_by_capacity(
torch::Tensor expert_count, torch::Tensor capacity,
long n_expert, long n_experts);
torch::Tensor _prune_gate_by_capacity(
torch::Tensor gate_idx, torch::Tensor expert_count,
long n_expert, long n_worker);
std::vector<torch::Tensor> _swipe_once(
torch::Tensor gate_idx, torch::Tensor capacity_tensor,
long n_expert, long n_worker, long bias);
// smart scheduling
std::vector<torch::Tensor> _smart_sch_forward(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long global_batch_size,
long expert_size,
long n_workers,
py::function forward_fn,
py::function get_param_fn,
py::function stash_fn,
py::function pop_fn);
torch::Tensor _smart_sch_backward(
torch::Tensor grad_out,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
torch::Tensor stored_models,
long buf_batch_size,
long global_batch_size,
long n_workers,
py::function backward_fn,
py::function stash_fn,
py::function pop_fn,
py::function collect_fn,
py::function set_grad_fn);
void _reduce_grad(
torch::Tensor t,
long root,
long expert_size);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifdef FMOE_USE_NCCL
m.def("expert_exchange", &_expert_exchange, "FastMoE expert exchange (CUDA)");
m.def("global_scatter", &_global_scatter, "FastMoE global scatter (CUDA)");
m.def("global_gather", &_global_gather, "FastMoE global gather (CUDA)");
m.def("ensure_nccl", &_ensure_nccl, "FastMoE ensure torch nccl comm");
m.def("swipe_once", &_swipe_once, "SWIPE balance strategy(CUDA)");
m.def("smart_sch_forward", &_smart_sch_forward, "E2E MoE layer forward with smart scheduling");
m.def("smart_sch_backward", &_smart_sch_backward, "E2E MoE layer backward with smart scheduling");
m.def("reduce_grad", &_reduce_grad, "Reduce gradients over FastMoE's communication stream");
#endif
m.def("expert_count", &_expert_count, "FastMoE count gate indices (CUDA)");
m.def("assign_pos", &_assign_pos, "FastMoE assign pos by gate (CUDA)");
m.def("linear_forward", &_linear_forward, "FastMoE forward (CUDA)");
m.def("linear_backward", &_linear_backward, "FastMoE backward (CUDA)");
m.def("limit_by_capacity", &_limit_by_capacity, "FastMoE limit experts by capacity(CUDA)");
m.def("prune_gate_by_capacity", &_prune_gate_by_capacity, "FastMoE prune gate by capacity(CUDA)");
}
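/*
 * Usage sketch (for illustration only; the module name is whatever
 * TORCH_EXTENSION_NAME is set to at build time, conventionally "fmoe_cuda"
 * in FastMoE's setup script):
 *
 *   import torch, fmoe_cuda
 *   gate_idx = torch.randint(0, 4, (1024,), dtype=torch.int64, device="cuda")
 *   expert_count = torch.zeros(4, dtype=torch.int32, device="cuda")
 *   fmoe_cuda.expert_count(gate_idx, expert_count)
 *
 * mirroring the expert_count / assign_pos pair driven by the Python layers.
 */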
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "../hip/global_exchange.h"
#include "../hip/utils/fmoe_utils.h"
#include <torch/extension.h>
#ifdef FMOE_USE_NCCL
#include <rccl/rccl.h>
void fmoe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int n_expert, int world_size,
CudaStreamManager* smgr) {
NCCL_SAFE_CALL(ncclGroupStart());
for (int i = 0; i < world_size; ++i) {
NCCL_SAFE_CALL(ncclSend(
local_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->torchStream()));
NCCL_SAFE_CALL(ncclRecv(
global_expert_count + n_expert * i,
n_expert,
ncclInt64,
i,
smgr->ncclcomm,
smgr->torchStream()));
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
torch::Tensor _expert_exchange(
torch::Tensor local_expert_count,
long n_expert, long n_workers) {
auto global_expert_count = torch::empty_like(local_expert_count);
auto smgr = getCudaStreamManager(local_expert_count.device().index());
fmoe_cuda_expert_exchange_impl(
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
n_expert, n_workers,
smgr);
return global_expert_count;
}
torch::Tensor _global_scatter(
torch::Tensor input_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
CHECK_INPUT(input_buf);
auto n_expert = local_expert_count.size(0) / n_workers;
auto in_feat = input_buf.size(1);
auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
auto smgr = getCudaStreamManager(input_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
input_buf.scalar_type(), "fmoe_cuda_global_scatter", ([&] {
fmoe_cuda_global_scatter_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
global_input_buf.data_ptr<scalar_t>(),
in_feat, n_expert, n_workers,
smgr
);
}));
return global_input_buf;
}
torch::Tensor _global_gather(
torch::Tensor output_buf,
torch::Tensor local_expert_count,
torch::Tensor global_expert_count,
long batch_size, long n_workers) {
CHECK_INPUT(output_buf);
auto n_expert = local_expert_count.size(0) / n_workers;
auto out_feat = output_buf.size(1);
auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
auto smgr = getCudaStreamManager(output_buf.device().index());
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
output_buf.scalar_type(), "fmoe_cuda_global_gather", ([&] {
fmoe_cuda_global_gather_impl<scalar_t>(
output_buf.data_ptr<scalar_t>(),
local_expert_count.data_ptr<long>(),
global_expert_count.data_ptr<long>(),
local_output_buf.data_ptr<scalar_t>(),
out_feat, n_expert, n_workers,
smgr
);
}));
return local_output_buf;
}
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else
#include <c10d/ProcessGroupNCCL.hpp>
#endif
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
ncclComm_t getcomm(at::Device dev) {
ncclUniqueId ncclID;
int rank = getRank();
if (rank == 0) {
ncclGetUniqueId(&ncclID);
}
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 12))
broadcastUniqueNCCLID(&ncclID,
false,
"fastmoe_nccl_comm",
rank);
#elif defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 8))
broadcastUniqueNCCLID(&ncclID,
c10d::OpType::SEND,
"fastmoe_nccl_comm",
rank);
#else
broadcastUniqueNCCLID(&ncclID);
#endif
ncclComm_t comm;
NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank));
return comm;
}
};
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) {
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
#endif // TORCH_VERSION
auto smgr = getCudaStreamManager(t.device().index());
if (smgr->ncclgood) {
return;
}
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
HackNCCLGroup* h = (HackNCCLGroup*)(void*)
(p.getBackend(c10d::ProcessGroup::NCCL).get());
#else
HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
#endif // TORCH_VERSION
smgr->ncclcomm = h->getcomm(t.device());
if (smgr->ncclcomm != 0) {
smgr->ncclgood = 1;
} else {
std::cerr << "Nccl initialization failed\n";
}
}
#endif // FMOE_USE_NCCL
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "../hip/stream_manager.h"
#ifdef FMOE_USE_NCCL
void fmoe_cuda_expert_exchange_impl(
const long* local_expert_count,
long* global_expert_count,
int n_expert, int world_size,
CudaStreamManager* smgr);
template<typename scalar_t>
void fmoe_cuda_global_scatter_impl(
const scalar_t* local_input_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* input_buf,
size_t in_feat, size_t n_expert, size_t world_size,
CudaStreamManager* smgr) {
// assert world_size > 1
int recv_ptr = 0;
/* TODO: may save for backward */
long*expert_ptr = new long[n_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < n_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < n_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * n_expert;
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
local_input_buf + expert_ptr[idx] * in_feat,
local_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->torchStream()));
}
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
input_buf + recv_ptr * in_feat,
global_expert_count[idx] * in_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->torchStream()));
recv_ptr += global_expert_count[idx];
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
}
template<typename scalar_t>
void fmoe_cuda_global_gather_impl(
const scalar_t* output_buf,
const long* local_expert_count,
const long* global_expert_count,
scalar_t* local_output_buf,
size_t out_feat, size_t n_expert, size_t world_size,
CudaStreamManager* smgr) {
long send_ptr = 0;
/* TODO: may save for backward */
long *expert_ptr = new long[n_expert * world_size];
expert_ptr[0] = 0;
for (size_t i = 1; i < n_expert * world_size; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + local_expert_count[i - 1];
}
for (size_t i = 0; i < n_expert; ++i) {
NCCL_SAFE_CALL(ncclGroupStart());
for (size_t j = 0; j < world_size; ++j) {
int idx = i + j * n_expert;
if (global_expert_count[idx]) {
NCCL_SAFE_CALL(ncclSend(
output_buf + send_ptr * out_feat,
global_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->torchStream()));
send_ptr += global_expert_count[idx];
}
if (local_expert_count[idx]) {
NCCL_SAFE_CALL(ncclRecv(
local_output_buf + expert_ptr[idx] * out_feat,
local_expert_count[idx] * out_feat * sizeof(scalar_t),
ncclChar,
j,
smgr->ncclcomm,
smgr->torchStream()));
}
}
NCCL_SAFE_CALL(ncclGroupEnd());
}
delete [] expert_ptr;
}
#endif // FMOE_USE_NCCL
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include "../hip/stream_manager.h"
#include "../hip/utils/helper_cuda.h"
#include "../hip/utils/fmoe_utils.h"
__global__
void assign_pos_kernel(int* cum_count, const long* gate, long* pos,
size_t numel, size_t topk) {
size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < numel) {
long gate_idx = gate[idx];
if (gate_idx > -1) {
int p = atomicSub(cum_count + gate_idx, 1);
pos[p - 1] = (long)idx;
}
}
}
void fmoe_cuda_assign_pos_impl(
int* cum_count, const long* gate, long* pos,
const size_t batch_size, const size_t topk,
CudaStreamManager* smgr) {
size_t numel = batch_size * topk;
hipLaunchKernelGGL(( assign_pos_kernel)
, dim3(CEIL(numel, 256)), dim3(256), 0, smgr->torchStream(),
cum_count, gate, pos, numel, topk);
}
#define PERTHREAD_EXPERTS 256
#ifdef FMOE_USE_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
__global__
void expert_count_kernel(const long* gate_idx, int* expert_count,
const size_t batch_size, const size_t n_expert) {
int res_tmp[PERTHREAD_EXPERTS] = {0};
long expert_min = blockIdx.x * PERTHREAD_EXPERTS;
long expert_max = expert_min + PERTHREAD_EXPERTS;
if (expert_max > n_expert) {
expert_max = n_expert;
}
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
long idx = gate_idx[i];
if (idx == -1) {
continue;
}
if (idx < expert_min || idx >= expert_max) {
continue;
}
res_tmp[idx - expert_min] += 1;
}
for (int i = expert_min; i < expert_max; ++i) {
int x = res_tmp[i - expert_min];
#pragma unroll
for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef FMOE_USE_HIP
x = x + __shfl_down(x, j);
#else
x = x + __shfl_down_sync(-1u, x, j);
#endif
}
if (threadIdx.x % WARP_SIZE == 0) {
atomicAdd(expert_count + i, x);
}
}
}
void fmoe_cuda_expert_count_impl(
const long* gate_idx, int* expert_count,
const size_t batch_size, const size_t n_expert,
CudaStreamManager* smgr) {
hipLaunchKernelGGL(( expert_count_kernel)
, dim3(CEIL(n_expert, PERTHREAD_EXPERTS)), dim3(256), 0, smgr->torchStream(),
gate_idx, expert_count, batch_size, n_expert);
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "../hip/local_exchange.cuh"
#include "../hip/utils/fmoe_utils.h"
#include <torch/extension.h>
void _assign_pos(
torch::Tensor cum_count,
torch::Tensor gate,
torch::Tensor pos) {
auto smgr = getCudaStreamManager(cum_count.device().index());
auto gate_shp = gate.sizes();
size_t batch_size = gate_shp[0], topk = 1;
if (gate_shp.size() == 2) {
topk = gate_shp[1];
}
fmoe_cuda_assign_pos_impl(
cum_count.data_ptr<int>(),
gate.data_ptr<long>(),
pos.data_ptr<long>(),
batch_size, topk, smgr);
}
void _expert_count(
torch::Tensor gate_idx,
torch::Tensor expert_count) {
auto smgr = getCudaStreamManager(gate_idx.device().index());
auto batch_size = gate_idx.numel();
auto n_expert = expert_count.numel();
fmoe_cuda_expert_count_impl(
gate_idx.data_ptr<long>(),
expert_count.data_ptr<int>(),
batch_size, n_expert, smgr);
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include "../hip/stream_manager.h"
#include "../hip/utils/cublas_wrapper.h"
/*
This function is to be called with one block per each column
*/
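/*
 * Launch-shape note: the backward pass below launches this kernel with a
 * 32x32 thread block, one block per group of 32 output columns, and
 * sizeof(scalar_t) * 1024 bytes of dynamic shared memory, so the reduction
 * assumes blockDim.x == blockDim.y == 32.
 */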
template <typename scalar_t>
__global__
void column_reduce(const scalar_t * matrix, scalar_t * result,
int m /* lines */, int n /* columns*/) {
// https://stackoverflow.com/questions/27570552/templated-cuda-kernel-with-dynamic-shared-memory
extern __shared__ unsigned char my_smem[];
scalar_t *sdata = reinterpret_cast<scalar_t *>(my_smem);
// normal tid
int tid = threadIdx.x + threadIdx.y * blockDim.x;
// transposed tid for shared memory
int new_tid = threadIdx.y + threadIdx.x * blockDim.y;
// true x value in the matrix
int real_x = threadIdx.x + blockDim.x * blockIdx.x;
int i = real_x + n * threadIdx.y;
const int it = n*blockDim.y;
int offset = it;
float accumulator = 0;
if (threadIdx.y < m && real_x < n) {
// store all the values from this column in a warped way
accumulator = matrix[i];
while (i + offset < n*m) {
accumulator += matrix[i + offset];
offset += it;
}
}
// save column reduction data in a transposed way
sdata[new_tid] = accumulator;
__syncthreads();
for (size_t t= 16; t > 0; t>>=1) {
if (tid < 32 * 32 - 16)
sdata[tid] += sdata[tid + t];
__syncthreads();
}
if (threadIdx.y == 0 && real_x < n)
result[real_x] = sdata[new_tid];
}
template <typename scalar_t>
void fmoe_cuda_linear_forward_impl(
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
scalar_t* output_buf,
const bool has_bias,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
scalar_t alpha = 1, beta = has_bias ? 1 : 0;
smgr->syncTorch();
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
//change alpha beta dtype
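// Descriptive note: hipBLAS is column-major, so the row-major product
// output = input x weight^T is obtained by computing its transpose,
// output^T = weight x input^T. OP_T on the weight (stored row-major as
// [out_feat, in_feat]) and OP_N on the input yield the
// out_feat x expert_count[i] column-major block written at output_buf.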
checkCudaErrors(cublasXgemm(
smgr->handle(i),
HIPBLAS_OP_T,
HIPBLAS_OP_N,
out_feat, expert_count[i], in_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
input_buf + ptr * in_feat, in_feat,
&beta,
output_buf + out_feat * ptr, out_feat
));
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
template <typename scalar_t>
void fmoe_cuda_linear_backward_impl(
const scalar_t* grad_output_buf,
const scalar_t* input_buf,
const scalar_t* weight,
const long* expert_count,
scalar_t* grad_input_buf,
scalar_t* grad_weight,
scalar_t* grad_bias,
const bool has_bias,
const size_t batch_size,
const size_t in_feat,
const size_t out_feat,
const size_t num_expert,
CudaStreamManager* smgr) {
smgr->syncTorch();
scalar_t alpha = 1, beta = 0;
// bias
dim3 block_threads(32, 32);
dim3 grid_threads(out_feat / 32 + (out_feat % 32 ? 1 : 0), 1);
for (int i = 0, ptr = 0; i < num_expert; ++i) {
if (expert_count[i] == 0) {
hipMemset(grad_weight + i * in_feat * out_feat, 0,
sizeof(scalar_t) * in_feat * out_feat);
hipMemset(grad_bias + i * out_feat, 0, sizeof(scalar_t) * out_feat);
continue;
}
// Use T(B) x T(A) = T(C) to produce row-major C
// Backward input: g_i = w @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_feat, expert_count[i], out_feat,
&alpha,
weight + i * in_feat * out_feat, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_input_buf + in_feat * ptr, in_feat
));
// Backward weight: g_w = i @ g_o
checkCudaErrors(cublasXgemm(
smgr->handle(i),
HIPBLAS_OP_N,
HIPBLAS_OP_T,
in_feat, out_feat, expert_count[i],
&alpha,
input_buf + in_feat * ptr, in_feat,
grad_output_buf + ptr * out_feat, out_feat,
&beta,
grad_weight + i * in_feat * out_feat, in_feat
));
if (has_bias) {
hipLaunchKernelGGL(( column_reduce)
, dim3(grid_threads), dim3(block_threads), sizeof(scalar_t)*1024, smgr->stream(i),
grad_output_buf + ptr * out_feat,
grad_bias + i * out_feat,
expert_count[i],
out_feat
);
}
ptr += expert_count[i];
}
smgr->sync(num_expert);
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "../hip/parallel_linear.cuh"
#include "../hip/utils/fmoe_utils.h"
#include <torch/extension.h>
torch::Tensor _linear_forward(
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[forward] expert=%ld, in_feat (d_model)=%ld, out_feat (d_ffn)=%ld\n",
num_expert, in_feat, out_feat);
#endif
torch::Tensor output;
if (bias.has_value()) {
output = bias.value().repeat_interleave(expert_count.to(bias.value().device()), 0);
} else{
auto out_options = torch::TensorOptions()
.device(input_buf.device())
.dtype(input_buf.dtype());
output = torch::empty({batch_size, out_feat}, out_options);
}
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
input_buf.scalar_type(), "moe_forward_cuda",
([&] {
fmoe_cuda_linear_forward_impl<scalar_t>(
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
output.data_ptr<scalar_t>(),
bias.has_value(),
in_feat,
out_feat,
num_expert,
smgr
);
}));
return output;
}
std::vector<torch::Tensor> _linear_backward(
torch::Tensor grad_output_buf,
torch::Tensor input_buf,
torch::Tensor expert_count,
torch::Tensor weight,
at::optional<torch::Tensor> bias
) {
auto smgr = getCudaStreamManager(input_buf.device().index());
const auto batch_size = input_buf.size(0);
const auto num_expert = weight.size(0);
const auto out_feat = weight.size(1);
const auto in_feat = weight.size(2);
#ifdef MOE_DEBUG
printf("[backward] b=%ld, expert=%ld, in_feat (d_model)=%ld, "
"out_feat (d_ffn)=%ld\n",
batch_size, num_expert, in_feat, out_feat);
#endif
auto grad_input_buf = grad_output_buf.new_empty({batch_size, in_feat});
auto grad_weight = grad_output_buf.new_empty({num_expert, out_feat, in_feat});
auto grad_bias = grad_output_buf.new_empty({num_expert, out_feat});
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
input_buf.scalar_type(), "moe_cuda_backward", ([&] {
fmoe_cuda_linear_backward_impl<scalar_t>(
grad_output_buf.data_ptr<scalar_t>(),
input_buf.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
expert_count.data_ptr<long>(),
grad_input_buf.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
grad_bias.data_ptr<scalar_t>(),
bias.has_value(),
batch_size,
in_feat,
out_feat,
num_expert,
smgr
);
}));
return {grad_input_buf, grad_weight, grad_bias};
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#include <unordered_map>
#include <mutex>
#include <cassert>
#include <thread>
#include <iostream>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/HIPContext.h>
#include "../hip/fastermoe/status.h"
#include "../hip/stream_manager.h"
#define SMGR_N_STREAMS 16
hipStream_t CudaStreamManager::stream(size_t idx) {
if (this->use_default) {
return c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
}
return this->streams[idx % SMGR_N_STREAMS];
}
hipStream_t CudaStreamManager::torchStream() {
return c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
}
hipblasHandle_t CudaStreamManager::handle(size_t idx) {
if (this->use_default) {
return at::cuda::getCurrentCUDABlasHandle();
}
return this->handles[idx % SMGR_N_STREAMS];
}
void CudaStreamManager::syncTorch() {
hipStreamSynchronize(this->torchStream());
}
void CudaStreamManager::sync(int idx) {
if (this->use_default) {
return;
}
for (int i = 0; i < idx && i < SMGR_N_STREAMS; ++i) {
hipStreamSynchronize(streams[i]);
}
}
void CudaStreamManager::setup(const int device) {
#ifdef FMOE_USE_NCCL
this->ncclgood = 0;
#endif
this->device = device;
checkCudaErrors(hipSetDevice(device));
streams = new hipStream_t[SMGR_N_STREAMS];
handles = new hipblasHandle_t[SMGR_N_STREAMS];
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
// SHOULD NOT USE: hipStreamCreate(...)
// more details in
// https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
checkCudaErrors(hipStreamCreateWithFlags(streams + i,
hipStreamNonBlocking));
checkCudaErrors(hipblasCreate(handles + i));
hipblasSetStream(handles[i], streams[i]);
}
}
void CudaStreamManager::destroy() {
for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
checkCudaErrors(hipStreamDestroy(streams[i]));
checkCudaErrors(hipblasDestroy(handles[i]));
}
delete[] streams;
delete[] handles;
}
std::unordered_map<int, CudaStreamManager*> smgrs;
std::mutex smgr_mtx;
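// Descriptive note: getCudaStreamManager uses double-checked locking -- the
// per-device map is probed once without the mutex and re-checked after
// acquiring it, so only the first caller for a device pays for stream and
// handle setup while subsequent lookups never take the lock.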
CudaStreamManager* getCudaStreamManager(const int device) {
auto it = smgrs.find(device);
if (it == smgrs.end()) {
smgr_mtx.lock();
it = smgrs.find(device);
if (it == smgrs.end()) {
auto smgr = new CudaStreamManager(device);
smgrs.insert(std::pair<int, CudaStreamManager*>(device, smgr));
smgr_mtx.unlock();
return smgr;
} else {
smgr_mtx.unlock();
}
}
return it->second;
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
#ifndef CUDA_STREAM_MANAGER_H
#define CUDA_STREAM_MANAGER_H
#include "../hip/utils/helper_cuda.h"
#ifdef FMOE_USE_NCCL
#include <rccl/rccl.h>
#define NCCL_SAFE_CALL(__fn__) { \
auto __res__ = __fn__; \
if (__res__ != ncclSuccess) { \
fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
exit(-1); \
} \
}
#endif
class CudaStreamManager {
public:
int device;
hipblasHandle_t* handles;
hipStream_t* streams;
bool use_default;
#ifdef FMOE_USE_NCCL
char ncclgood;
ncclComm_t ncclcomm;
#endif
public:
CudaStreamManager(int device_): device(device_), use_default(false) {
this->setup(device);
}
void setup(int);
void sync(int=0);
void syncTorch();
void destroy();
hipStream_t torchStream();
hipStream_t stream(size_t=0);
hipblasHandle_t handle(size_t=0);
~CudaStreamManager() {
this->destroy();
}
};
CudaStreamManager* getCudaStreamManager(const int device);
#endif // CUDA_STREAM_MANAGER
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include <iostream>
#ifndef CUBLAS_WRAPPER_H
#define CUBLAS_WRAPPER_H
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>
#include <hipblas/hipblas.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include </opt/dtk/hip/include/hip/amd_detail/amd_hip_bf16.h>
//#include </opt/dtk/include/rocblas/internal/rocblas-types.h>
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *Aarray[], int lda,
const float *Barray[], int ldb,
const float *beta,
float *Carray[], int ldc,
int batchCount) {
return hipblasSgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m, int n, int k,
const double *alpha,
const double *Aarray[], int lda,
const double *Barray[], int ldb,
const double *beta,
double *Carray[], int ldc,
int batchCount) {
return hipblasDgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
}
inline hipblasStatus_t cublasXgemmBatched(hipblasHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m, int n, int k,
const __half *alpha,
const __half *Aarray[], int lda,
const __half *Barray[], int ldb,
const __half *beta,
__half *Carray[], int ldc,
int batchCount) {
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
//#ifdef FMOE_USE_HIP
return hipblasHgemmBatched(handle, transa, transb, m, n, k, (const rocblas_half*)alpha, (const rocblas_half* const*)Aarray, lda, (const rocblas_half* const*)Barray, ldb, (const rocblas_half*)beta, (rocblas_half* const*)Carray, ldc, batchCount);
#else
// return hipblasHgemmBatched(handle, transa, transb, m, n, k, alpha, Aarray, lda, Barray, ldb, beta, Carray, ldc, batchCount);
return hipblasHgemmBatched(handle, transa, transb, m, n, k, (const hipblasHalf*)alpha, (const hipblasHalf* const*)Aarray, lda, (const hipblasHalf* const*)Barray, ldb, (const hipblasHalf*)beta, (hipblasHalf* const*)Carray, ldc, batchCount);
#endif
}
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
hipblasOperation_t transa, hipblasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc) {
return hipblasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
hipblasOperation_t transa, hipblasOperation_t transb,
int m, int n, int k,
const double *alpha,
const double *A, int lda,
const double *B, int ldb,
const double *beta,
double *C, int ldc) {
return hipblasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
hipblasOperation_t transa, hipblasOperation_t transb,
int m, int n, int k,
const __half *alpha,
const __half *A, int lda,
const __half *B, int ldb,
const __half *beta,
__half *C, int ldc) {
//#ifdef FMOE_USE_HIP
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
return hipblasHgemm(handle, transa, transb, m, n, k, (const rocblas_half*)alpha, (const rocblas_half* )A, lda, (const rocblas_half* )B, ldb, (const rocblas_half*)beta, (rocblas_half* )C, ldc);
#else
// return hipblasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
return hipblasHgemm(handle, transa, transb, m, n, k, (const hipblasHalf*)alpha, (const hipblasHalf*)A, lda, (const hipblasHalf*)B, ldb, (const hipblasHalf*)beta, (hipblasHalf*)C, ldc);
#endif
}
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
hipblasOperation_t transa, hipblasOperation_t transb,
int m, int n, int k,
const c10::Half *alpha,
const c10::Half *A, int lda,
const c10::Half *B, int ldb,
const c10::Half *beta,
c10::Half *C, int ldc) {
//#ifdef FMOE_USE_HIP
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
return hipblasHgemm(handle, transa, transb, m, n, k,
(const rocblas_half*)alpha,
(const rocblas_half*)A, lda,
(const rocblas_half*)B, ldb,
(const rocblas_half*)beta,
(rocblas_half*)C, ldc);
#else
return hipblasHgemm(handle, transa, transb, m, n, k,
//(const __half*)alpha,
(const hipblasHalf*)alpha,
//(const __half*)A, lda,
(const hipblasHalf*)A, lda,
//(const __half*)B, ldb,
(const hipblasHalf*)B, ldb,
//(const __half*)beta,
(const hipblasHalf*)beta,
//(__half*)C, ldc);
(hipblasHalf*)C, ldc);
#endif
}
inline hipblasStatus_t cublasXgemm(hipblasHandle_t handle,
hipblasOperation_t transa, hipblasOperation_t transb,
int m, int n, int k,
const c10::BFloat16 *alpha,
//const void *alpha,
const c10::BFloat16 *A, int lda,
const c10::BFloat16 *B, int ldb,
const c10::BFloat16 *beta,
//const void *beta,
c10::BFloat16 *C, int ldc) {
//#ifdef FMOE_USE_HIP
#if defined (FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
// TODO: Support bf16 for HIP
assert(false);
#else
//const float alpha_fp32(*alpha), beta_fp32(*beta);
hipblasDatatype_t datatype_C = HIPBLAS_R_16B;
float alpha_ = static_cast<float>(*alpha);
float beta_ = static_cast<float>(*beta);
return hipblasGemmEx(handle, transa, transb, m, n, k,
//(const float*)&alpha_fp32,
//(const void*)A, datatype_C, lda,
//(const void*)B, datatype_C, ldb,
//(const float*)&beta_fp32,
//(void*)C, datatype_C, ldc,
reinterpret_cast<const float*>(&alpha_),
//alpha,
reinterpret_cast<const void*>(A), datatype_C, lda,
reinterpret_cast<const void*>(B), datatype_C, ldb,
reinterpret_cast<const float*>(&beta_),
//beta,
reinterpret_cast<void*>(C), datatype_C, ldc,
HIPBLAS_R_32F,
HIPBLAS_GEMM_DEFAULT);
#endif
}
#endif // CUBLAS_WRAPPER_H
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifndef FMOE_UTILS_H
#define FMOE_UTILS_H
#define CHECK_CUDA(x) AT_ASSERTM(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)
#endif // FMOE_UTILS_H
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
/* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
////////////////////////////////////////////////////////////////////////////////
// These are CUDA Helper functions for initialization and error checking
// This file is clipped from the original header file by laekov
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <stdio.h>
#include <stdlib.h>
#ifndef HELPER_CUDA_H
#define HELPER_CUDA_H
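// HIP runtime API errors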
static const char *_cudaGetErrorEnum(hipError_t error) {
return hipGetErrorName(error);
}
// Note: the CUDA driver-API overload (CUresult) collapses onto the same
// hipError_t overload after hipify, so no separate definition is needed here.
#if defined(FMOE_USE_HIP) && defined(__CUDA_MIX_HIP__)
static const char *_cudaGetErrorEnum(hipblasStatus_t error) {
switch (error) {
case HIPBLAS_STATUS_SUCCESS:
return "HIPBLAS_STATUS_SUCCESS";
case HIPBLAS_STATUS_NOT_INITIALIZED:
return "HIPBLAS_STATUS_NOT_INITIALIZED";
case HIPBLAS_STATUS_ARCH_MISMATCH:
return "HIPBLAS_STATUS_ARCH_MISMATCH";
case HIPBLAS_STATUS_INVALID_VALUE:
return "HIPBLAS_STATUS_INVALID_VALUE:";
case rocblas_status_invalid_size:
return "rocblas_status_invalid_size";
case HIPBLAS_STATUS_ALLOC_FAILED:
return "HIPBLAS_STATUS_ALLOC_FAILED";
case HIPBLAS_STATUS_INTERNAL_ERROR:
return "HIPBLAS_STATUS_INTERNAL_ERROR";
case rocblas_status_perf_degraded:
return "rocblas_status_perf_degraded";
case rocblas_status_size_query_mismatch:
return "rocblas_status_size_query_mismatch";
case rocblas_status_size_increased:
return "rocblas_status_size_increased";
case rocblas_status_size_unchanged:
return "rocblas_status_size_unchanged";
case rocblas_status_invalid_value:
return "rocblas_status_invalid_value";
case rocblas_status_continue:
return "rocblas_status_continue";
}
return "<unknown>";
}
#else
// cuBLAS API errors
static const char *_cudaGetErrorEnum(hipblasStatus_t error) {
switch (error) {
case HIPBLAS_STATUS_SUCCESS:
return "HIPBLAS_STATUS_SUCCESS";
case HIPBLAS_STATUS_NOT_INITIALIZED:
return "HIPBLAS_STATUS_NOT_INITIALIZED";
case HIPBLAS_STATUS_ALLOC_FAILED:
return "HIPBLAS_STATUS_ALLOC_FAILED";
case HIPBLAS_STATUS_INVALID_VALUE:
return "HIPBLAS_STATUS_INVALID_VALUE";
case HIPBLAS_STATUS_ARCH_MISMATCH:
return "HIPBLAS_STATUS_ARCH_MISMATCH";
case HIPBLAS_STATUS_MAPPING_ERROR:
return "HIPBLAS_STATUS_MAPPING_ERROR";
case HIPBLAS_STATUS_EXECUTION_FAILED:
return "HIPBLAS_STATUS_EXECUTION_FAILED";
case HIPBLAS_STATUS_INTERNAL_ERROR:
return "HIPBLAS_STATUS_INTERNAL_ERROR";
case HIPBLAS_STATUS_NOT_SUPPORTED:
return "HIPBLAS_STATUS_NOT_SUPPORTED";
}
return "<unknown>";
}
#endif
#ifdef _CUFFT_H_
// cuFFT API errors
static const char *_cudaGetErrorEnum(hipfftResult error) {
switch (error) {
case HIPFFT_SUCCESS:
return "HIPFFT_SUCCESS";
case HIPFFT_INVALID_PLAN:
return "HIPFFT_INVALID_PLAN";
case HIPFFT_ALLOC_FAILED:
return "HIPFFT_ALLOC_FAILED";
case HIPFFT_INVALID_TYPE:
return "HIPFFT_INVALID_TYPE";
case HIPFFT_INVALID_VALUE:
return "HIPFFT_INVALID_VALUE";
case HIPFFT_INTERNAL_ERROR:
return "HIPFFT_INTERNAL_ERROR";
case HIPFFT_EXEC_FAILED:
return "HIPFFT_EXEC_FAILED";
case HIPFFT_SETUP_FAILED:
return "HIPFFT_SETUP_FAILED";
case HIPFFT_INVALID_SIZE:
return "HIPFFT_INVALID_SIZE";
case HIPFFT_UNALIGNED_DATA:
return "HIPFFT_UNALIGNED_DATA";
case HIPFFT_INCOMPLETE_PARAMETER_LIST:
return "HIPFFT_INCOMPLETE_PARAMETER_LIST";
case HIPFFT_INVALID_DEVICE:
return "HIPFFT_INVALID_DEVICE";
case HIPFFT_PARSE_ERROR:
return "HIPFFT_PARSE_ERROR";
case HIPFFT_NO_WORKSPACE:
return "HIPFFT_NO_WORKSPACE";
case HIPFFT_NOT_IMPLEMENTED:
return "HIPFFT_NOT_IMPLEMENTED";
case HIPFFT_LICENSE_ERROR:
return "HIPFFT_LICENSE_ERROR";
case HIPFFT_NOT_SUPPORTED:
return "HIPFFT_NOT_SUPPORTED";
}
return "<unknown>";
}
#endif
#ifdef CUSPARSEAPI
// cuSPARSE API errors
static const char *_cudaGetErrorEnum(hipsparseStatus_t error) {
switch (error) {
case HIPSPARSE_STATUS_SUCCESS:
return "HIPSPARSE_STATUS_SUCCESS";
case HIPSPARSE_STATUS_NOT_INITIALIZED:
return "HIPSPARSE_STATUS_NOT_INITIALIZED";
case HIPSPARSE_STATUS_ALLOC_FAILED:
return "HIPSPARSE_STATUS_ALLOC_FAILED";
case HIPSPARSE_STATUS_INVALID_VALUE:
return "HIPSPARSE_STATUS_INVALID_VALUE";
case HIPSPARSE_STATUS_ARCH_MISMATCH:
return "HIPSPARSE_STATUS_ARCH_MISMATCH";
case HIPSPARSE_STATUS_MAPPING_ERROR:
return "HIPSPARSE_STATUS_MAPPING_ERROR";
case HIPSPARSE_STATUS_EXECUTION_FAILED:
return "HIPSPARSE_STATUS_EXECUTION_FAILED";
case HIPSPARSE_STATUS_INTERNAL_ERROR:
return "HIPSPARSE_STATUS_INTERNAL_ERROR";
case HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "HIPSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
}
return "<unknown>";
}
#endif
#ifdef CUSOLVER_COMMON_H_
// cuSOLVER API errors
static const char *_cudaGetErrorEnum(cusolverStatus_t error) {
switch (error) {
case CUSOLVER_STATUS_SUCCESS:
return "CUSOLVER_STATUS_SUCCESS";
case CUSOLVER_STATUS_NOT_INITIALIZED:
return "CUSOLVER_STATUS_NOT_INITIALIZED";
case CUSOLVER_STATUS_ALLOC_FAILED:
return "CUSOLVER_STATUS_ALLOC_FAILED";
case CUSOLVER_STATUS_INVALID_VALUE:
return "CUSOLVER_STATUS_INVALID_VALUE";
case CUSOLVER_STATUS_ARCH_MISMATCH:
return "CUSOLVER_STATUS_ARCH_MISMATCH";
case CUSOLVER_STATUS_MAPPING_ERROR:
return "CUSOLVER_STATUS_MAPPING_ERROR";
case CUSOLVER_STATUS_EXECUTION_FAILED:
return "CUSOLVER_STATUS_EXECUTION_FAILED";
case CUSOLVER_STATUS_INTERNAL_ERROR:
return "CUSOLVER_STATUS_INTERNAL_ERROR";
case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
case CUSOLVER_STATUS_NOT_SUPPORTED:
return "CUSOLVER_STATUS_NOT_SUPPORTED ";
case CUSOLVER_STATUS_ZERO_PIVOT:
return "CUSOLVER_STATUS_ZERO_PIVOT";
case CUSOLVER_STATUS_INVALID_LICENSE:
return "CUSOLVER_STATUS_INVALID_LICENSE";
}
return "<unknown>";
}
#endif
#ifdef CURAND_H_
// cuRAND API errors
static const char *_cudaGetErrorEnum(hiprandStatus_t error) {
switch (error) {
case HIPRAND_STATUS_SUCCESS:
return "HIPRAND_STATUS_SUCCESS";
case HIPRAND_STATUS_VERSION_MISMATCH:
return "HIPRAND_STATUS_VERSION_MISMATCH";
case HIPRAND_STATUS_NOT_INITIALIZED:
return "HIPRAND_STATUS_NOT_INITIALIZED";
case HIPRAND_STATUS_ALLOCATION_FAILED:
return "HIPRAND_STATUS_ALLOCATION_FAILED";
case HIPRAND_STATUS_TYPE_ERROR:
return "HIPRAND_STATUS_TYPE_ERROR";
case HIPRAND_STATUS_OUT_OF_RANGE:
return "HIPRAND_STATUS_OUT_OF_RANGE";
case HIPRAND_STATUS_LENGTH_NOT_MULTIPLE:
return "HIPRAND_STATUS_LENGTH_NOT_MULTIPLE";
case HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case HIPRAND_STATUS_LAUNCH_FAILURE:
return "HIPRAND_STATUS_LAUNCH_FAILURE";
case HIPRAND_STATUS_PREEXISTING_FAILURE:
return "HIPRAND_STATUS_PREEXISTING_FAILURE";
case HIPRAND_STATUS_INITIALIZATION_FAILED:
return "HIPRAND_STATUS_INITIALIZATION_FAILED";
case HIPRAND_STATUS_ARCH_MISMATCH:
return "HIPRAND_STATUS_ARCH_MISMATCH";
case HIPRAND_STATUS_INTERNAL_ERROR:
return "HIPRAND_STATUS_INTERNAL_ERROR";
}
return "<unknown>";
}
#endif
#ifdef NVJPEGAPI
// nvJPEG API errors
static const char *_cudaGetErrorEnum(nvjpegStatus_t error) {
switch (error) {
case NVJPEG_STATUS_SUCCESS:
return "NVJPEG_STATUS_SUCCESS";
case NVJPEG_STATUS_NOT_INITIALIZED:
return "NVJPEG_STATUS_NOT_INITIALIZED";
case NVJPEG_STATUS_INVALID_PARAMETER:
return "NVJPEG_STATUS_INVALID_PARAMETER";
case NVJPEG_STATUS_BAD_JPEG:
return "NVJPEG_STATUS_BAD_JPEG";
case NVJPEG_STATUS_JPEG_NOT_SUPPORTED:
return "NVJPEG_STATUS_JPEG_NOT_SUPPORTED";
case NVJPEG_STATUS_ALLOCATOR_FAILURE:
return "NVJPEG_STATUS_ALLOCATOR_FAILURE";
case NVJPEG_STATUS_EXECUTION_FAILED:
return "NVJPEG_STATUS_EXECUTION_FAILED";
case NVJPEG_STATUS_ARCH_MISMATCH:
return "NVJPEG_STATUS_ARCH_MISMATCH";
case NVJPEG_STATUS_INTERNAL_ERROR:
return "NVJPEG_STATUS_INTERNAL_ERROR";
}
return "<unknown>";
}
#endif
#ifdef NV_NPPIDEFS_H
// NPP API errors
static const char *_cudaGetErrorEnum(NppStatus error) {
switch (error) {
case NPP_NOT_SUPPORTED_MODE_ERROR:
return "NPP_NOT_SUPPORTED_MODE_ERROR";
case NPP_ROUND_MODE_NOT_SUPPORTED_ERROR:
return "NPP_ROUND_MODE_NOT_SUPPORTED_ERROR";
case NPP_RESIZE_NO_OPERATION_ERROR:
return "NPP_RESIZE_NO_OPERATION_ERROR";
case NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY:
return "NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY";
#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000
case NPP_BAD_ARG_ERROR:
return "NPP_BAD_ARGUMENT_ERROR";
case NPP_COEFF_ERROR:
return "NPP_COEFFICIENT_ERROR";
case NPP_RECT_ERROR:
return "NPP_RECTANGLE_ERROR";
case NPP_QUAD_ERROR:
return "NPP_QUADRANGLE_ERROR";
case NPP_MEM_ALLOC_ERR:
return "NPP_MEMORY_ALLOCATION_ERROR";
case NPP_HISTO_NUMBER_OF_LEVELS_ERROR:
return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR";
case NPP_INVALID_INPUT:
return "NPP_INVALID_INPUT";
case NPP_POINTER_ERROR:
return "NPP_POINTER_ERROR";
case NPP_WARNING:
return "NPP_WARNING";
case NPP_ODD_ROI_WARNING:
return "NPP_ODD_ROI_WARNING";
#else
// These are for CUDA 5.5 or higher
case NPP_BAD_ARGUMENT_ERROR:
return "NPP_BAD_ARGUMENT_ERROR";
case NPP_COEFFICIENT_ERROR:
return "NPP_COEFFICIENT_ERROR";
case NPP_RECTANGLE_ERROR:
return "NPP_RECTANGLE_ERROR";
case NPP_QUADRANGLE_ERROR:
return "NPP_QUADRANGLE_ERROR";
case NPP_MEMORY_ALLOCATION_ERR:
return "NPP_MEMORY_ALLOCATION_ERROR";
case NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR:
return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR";
case NPP_INVALID_HOST_POINTER_ERROR:
return "NPP_INVALID_HOST_POINTER_ERROR";
case NPP_INVALID_DEVICE_POINTER_ERROR:
return "NPP_INVALID_DEVICE_POINTER_ERROR";
#endif
case NPP_LUT_NUMBER_OF_LEVELS_ERROR:
return "NPP_LUT_NUMBER_OF_LEVELS_ERROR";
case NPP_TEXTURE_BIND_ERROR:
return "NPP_TEXTURE_BIND_ERROR";
case NPP_WRONG_INTERSECTION_ROI_ERROR:
return "NPP_WRONG_INTERSECTION_ROI_ERROR";
case NPP_NOT_EVEN_STEP_ERROR:
return "NPP_NOT_EVEN_STEP_ERROR";
case NPP_INTERPOLATION_ERROR:
return "NPP_INTERPOLATION_ERROR";
case NPP_RESIZE_FACTOR_ERROR:
return "NPP_RESIZE_FACTOR_ERROR";
case NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR:
return "NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR";
#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000
case NPP_MEMFREE_ERR:
return "NPP_MEMFREE_ERR";
case NPP_MEMSET_ERR:
return "NPP_MEMSET_ERR";
case NPP_MEMCPY_ERR:
return "NPP_MEMCPY_ERROR";
case NPP_MIRROR_FLIP_ERR:
return "NPP_MIRROR_FLIP_ERR";
#else
case NPP_MEMFREE_ERROR:
return "NPP_MEMFREE_ERROR";
case NPP_MEMSET_ERROR:
return "NPP_MEMSET_ERROR";
case NPP_MEMCPY_ERROR:
return "NPP_MEMCPY_ERROR";
case NPP_MIRROR_FLIP_ERROR:
return "NPP_MIRROR_FLIP_ERROR";
#endif
case NPP_ALIGNMENT_ERROR:
return "NPP_ALIGNMENT_ERROR";
case NPP_STEP_ERROR:
return "NPP_STEP_ERROR";
case NPP_SIZE_ERROR:
return "NPP_SIZE_ERROR";
case NPP_NULL_POINTER_ERROR:
return "NPP_NULL_POINTER_ERROR";
case NPP_CUDA_KERNEL_EXECUTION_ERROR:
return "NPP_CUDA_KERNEL_EXECUTION_ERROR";
case NPP_NOT_IMPLEMENTED_ERROR:
return "NPP_NOT_IMPLEMENTED_ERROR";
case NPP_ERROR:
return "NPP_ERROR";
case NPP_SUCCESS:
return "NPP_SUCCESS";
case NPP_WRONG_INTERSECTION_QUAD_WARNING:
return "NPP_WRONG_INTERSECTION_QUAD_WARNING";
case NPP_MISALIGNED_DST_ROI_WARNING:
return "NPP_MISALIGNED_DST_ROI_WARNING";
case NPP_AFFINE_QUAD_INCORRECT_WARNING:
return "NPP_AFFINE_QUAD_INCORRECT_WARNING";
case NPP_DOUBLE_SIZE_WARNING:
return "NPP_DOUBLE_SIZE_WARNING";
case NPP_WRONG_INTERSECTION_ROI_WARNING:
return "NPP_WRONG_INTERSECTION_ROI_WARNING";
#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x6000
/* These are 6.0 or higher */
case NPP_LUT_PALETTE_BITSIZE_ERROR:
return "NPP_LUT_PALETTE_BITSIZE_ERROR";
case NPP_ZC_MODE_NOT_SUPPORTED_ERROR:
return "NPP_ZC_MODE_NOT_SUPPORTED_ERROR";
case NPP_QUALITY_INDEX_ERROR:
return "NPP_QUALITY_INDEX_ERROR";
case NPP_CHANNEL_ORDER_ERROR:
return "NPP_CHANNEL_ORDER_ERROR";
case NPP_ZERO_MASK_VALUE_ERROR:
return "NPP_ZERO_MASK_VALUE_ERROR";
case NPP_NUMBER_OF_CHANNELS_ERROR:
return "NPP_NUMBER_OF_CHANNELS_ERROR";
case NPP_COI_ERROR:
return "NPP_COI_ERROR";
case NPP_DIVISOR_ERROR:
return "NPP_DIVISOR_ERROR";
case NPP_CHANNEL_ERROR:
return "NPP_CHANNEL_ERROR";
case NPP_STRIDE_ERROR:
return "NPP_STRIDE_ERROR";
case NPP_ANCHOR_ERROR:
return "NPP_ANCHOR_ERROR";
case NPP_MASK_SIZE_ERROR:
return "NPP_MASK_SIZE_ERROR";
case NPP_MOMENT_00_ZERO_ERROR:
return "NPP_MOMENT_00_ZERO_ERROR";
case NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR:
return "NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR";
case NPP_THRESHOLD_ERROR:
return "NPP_THRESHOLD_ERROR";
case NPP_CONTEXT_MATCH_ERROR:
return "NPP_CONTEXT_MATCH_ERROR";
case NPP_FFT_FLAG_ERROR:
return "NPP_FFT_FLAG_ERROR";
case NPP_FFT_ORDER_ERROR:
return "NPP_FFT_ORDER_ERROR";
case NPP_SCALE_RANGE_ERROR:
return "NPP_SCALE_RANGE_ERROR";
case NPP_DATA_TYPE_ERROR:
return "NPP_DATA_TYPE_ERROR";
case NPP_OUT_OFF_RANGE_ERROR:
return "NPP_OUT_OFF_RANGE_ERROR";
case NPP_DIVIDE_BY_ZERO_ERROR:
return "NPP_DIVIDE_BY_ZERO_ERROR";
case NPP_RANGE_ERROR:
return "NPP_RANGE_ERROR";
case NPP_NO_MEMORY_ERROR:
return "NPP_NO_MEMORY_ERROR";
case NPP_ERROR_RESERVED:
return "NPP_ERROR_RESERVED";
case NPP_NO_OPERATION_WARNING:
return "NPP_NO_OPERATION_WARNING";
case NPP_DIVIDE_BY_ZERO_WARNING:
return "NPP_DIVIDE_BY_ZERO_WARNING";
#endif
#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x7000
/* These are 7.0 or higher */
case NPP_OVERFLOW_ERROR:
return "NPP_OVERFLOW_ERROR";
case NPP_CORRUPTED_DATA_ERROR:
return "NPP_CORRUPTED_DATA_ERROR";
#endif
}
return "<unknown>";
}
#endif
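// Generic error checker: on a non-zero result, print the decoded error name
// together with the failing call, file and line, then abort.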
template <typename T>
void check(T result, char const *const func, const char *const file,
int const line) {
if (result) {
fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
static_cast<unsigned int>(result), _cudaGetErrorEnum(result), func);
exit(EXIT_FAILURE);
}
}
// This will output the proper CUDA error strings in the event
// that a CUDA host call returns an error
#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
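// Example (illustrative; d_buf and bytes are placeholders):
//   float *d_buf = nullptr;
//   checkCudaErrors(hipMalloc((void**)&d_buf, bytes));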
#endif // HELPER_CUDA_H