Commit 194e19bd authored by baominghelly's avatar baominghelly
Browse files

issue 787 - Merge from main and resolve conflict

parents 147a4ac7 215d1932
......@@ -11,7 +11,7 @@
namespace infinicore {
namespace context {
void setDevice(Device device, bool force_cpu = false);
void setDevice(Device device);
Device getDevice();
size_t getDeviceCount(Device::Type type);
......
......@@ -36,6 +36,10 @@ public:
return cache_vector[device_index];
}
// Convenience overload: look up the per-device cache from a Device value
// by unpacking it into the (type, index) form the primary overload expects.
BaseCache &getCache(Device device) {
    const auto dev_type = device.getType();
    const auto dev_index = device.getIndex();
    return getCache(dev_type, dev_index);
}
void setCapacity(size_t capacity) {
capacity_ = capacity;
for (auto &vec : caches_) {
......
......@@ -23,13 +23,13 @@ def get_device_count(device_type):
return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type)
def set_device(device, force_cpu=False):
def set_device(device):
"""Set the current active device.
Args:
device: The device to set as active
"""
_infinicore.set_device(device._underlying, force_cpu)
_infinicore.set_device(device._underlying)
def sync_stream():
......
......@@ -33,15 +33,11 @@ Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get();
}
void ContextImpl::setDevice(Device device, bool force_cpu) {
void ContextImpl::setDevice(Device device) {
if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set.
return;
}
if (device == Device(Device::Type::CPU, 0) && !force_cpu) {
// if not forced, no need to switch to CPU device runtime
return;
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before.
......@@ -87,8 +83,8 @@ ContextImpl::ContextImpl() {
namespace context {
void setDevice(Device device, bool force_cpu) {
ContextImpl::singleton().setDevice(device, force_cpu);
void setDevice(Device device) {
ContextImpl::singleton().setDevice(device);
}
Device getDevice() {
......
......@@ -21,7 +21,7 @@ public:
Runtime *getCpuRuntime();
void setDevice(Device, bool force_cpu = false);
void setDevice(Device);
size_t getDeviceCount(Device::Type type);
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopAddDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
context::getInfiniopHandle(c->device()), &desc,
context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAttentionDescriptor_t> caches(
void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
size_t seed = hash_combine(out, q, k, v, k_cache, v_cache, pos);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopAttentionDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
context::getInfiniopHandle(out->device()), &desc,
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k->desc(), v->desc(),
k_cache->desc(), v_cache->desc(), pos));
cache.put(seed, desc);
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches(
void calculate(Tensor output, Tensor input) {
size_t seed = hash_combine(output, input);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopCausalSoftmaxDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
context::getInfiniopHandle(output->device()), &desc,
context::getInfiniopHandle(device), &desc,
output->desc(), input->desc()));
cache.put(seed, desc);
} else {
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
size_t seed = hash_combine(c, b, a, alpha, beta);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopGemmDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(c->device()), &desc,
context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopMulDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
context::getInfiniopHandle(c->device()), &desc,
context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -25,17 +25,15 @@ static void calculate(
// cache per (result desc + logits desc) on device
size_t seed = hash_combine(indices, logits);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopRandomSampleDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(indices->device()), &desc,
context::getInfiniopHandle(device), &desc,
indices->desc(), logits->desc()));
cache.put(seed, desc);
} else {
......
......@@ -18,16 +18,14 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches(
void calculate(Tensor y, Tensor x) {
size_t seed = hash_combine(y, x);
auto device_type = y->device().getType();
auto device_index = y->device().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopRearrangeDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(device), &desc, y->desc(), x->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches(
void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, x, weight, epsilon);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopRMSNormDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
context::getInfiniopHandle(y->device()), &desc,
context::getInfiniopHandle(device), &desc,
y->desc(), x->desc(), weight->desc(), epsilon));
cache.put(seed, desc);
} else {
......
......@@ -33,16 +33,15 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache);
hash_combine(key, std::hash<int>()(static_cast<int>(infiniop_algo)));
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(key);
infiniopRoPEDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
context::getInfiniopHandle(x_out->device()), &desc,
context::getInfiniopHandle(device), &desc,
x_out->desc(), x->desc(),
pos->desc(), sin_cache->desc(), cos_cache->desc(),
infiniop_algo));
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSiluDescriptor_t> caches(
void calculate(Tensor output, Tensor input) {
size_t seed = hash_combine(output, input);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopSiluDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
context::getInfiniopHandle(output->device()), &desc,
context::getInfiniopHandle(device), &desc,
output->desc(), input->desc()));
cache.put(seed, desc);
} else {
......
......@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto device = context::getDevice();
auto &cache = caches.getCache(device);
auto desc_opt = cache.get(seed);
infiniopSwiGLUDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
context::getInfiniopHandle(c->device()), &desc,
context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -16,8 +16,7 @@ inline void bind(py::module &m) {
py::arg("device_type"));
m.def("set_device", &setDevice,
"Set the current active device",
py::arg("device"),
py::arg("force_cpu"));
py::arg("device"));
// Stream and handle management
m.def("get_stream", &getStream, "Get the current stream");
......
......@@ -31,6 +31,7 @@ void TensorImpl::copy_from(Tensor src) {
// Use nbytes() to get the actual tensor size, not the full memory size
size_t copy_size = std::min(this->nbytes(), src->nbytes());
if (this->device().getType() == Device::Type::CPU) {
context::setDevice(src->device());
if (this->is_contiguous()) {
context::memcpyD2H(this->data(), src->data(), copy_size);
} else {
......@@ -39,7 +40,7 @@ void TensorImpl::copy_from(Tensor src) {
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src);
}
} else if (src->device().getType() == Device::Type::CPU) {
context::setDevice(this->device());
if (this->is_contiguous()) {
context::memcpyH2D(this->data(), src->data(), copy_size);
} else {
......
#ifndef _TOPKROUTER_KERNEL_CUH__
#define _TOPKROUTER_KERNEL_CUH__
#include <cfloat>
#include <cub/block/block_load.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
template <typename T>
inline __device__ float exp_func(T x) {
float data;
if constexpr (std::is_same_v<T, float>) {
data = x;
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
} else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
data = __bfloat162float(x);
} else if constexpr (std::is_same_v<T, half>) {
data = __half2float(x);
......
// Metax backend header for the top-k router operator.
#ifndef __TOPKROUTER_METAX_H__
#define __TOPKROUTER_METAX_H__
#include "../topkrouter.h"
// Presumably expands to the metax-specific descriptor declaration for this
// operator (macro defined in ../topkrouter.h) — TODO confirm against that header.
DESCRIPTOR(metax)
#endif // __TOPKROUTER_METAX_H__
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment