Unverified Commit 784139b9 authored by thatPepe, committed by GitHub

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
#include "infinicore/ops/scaled_mm_i8.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(I8Gemm);
I8Gemm::I8Gemm(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a_p, a_s, b_p, b_s);
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a_p, a_s, b_p, b_s, bias);
}
void I8Gemm::execute(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(I8Gemm, c, a_p, a_s, b_p, b_s, bias);
}
void scaled_mm_i8_(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
I8Gemm::execute(c, a_p, a_s, b_p, b_s, bias);
}
} // namespace infinicore::op
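For reference, a minimal usage sketch of the in-place entry point above (editor's illustration, not part of the commit). The shapes, the scale dtype (F32 assumed), the output dtype, and the device handle `dev` are placeholders; only op::scaled_mm_i8_, Tensor::empty, and DataType::I8 come from this repository.
void scaled_mm_i8_example(const infinicore::Device &dev) {
    using namespace infinicore;
    constexpr size_t M = 4, K = 128, N = 256;
    auto a_p = Tensor::empty({M, K}, DataType::I8, dev);   // quantized activations
    auto a_s = Tensor::empty({M, 1}, DataType::F32, dev);  // per-row activation scales (dtype assumed)
    auto b_p = Tensor::empty({K, N}, DataType::I8, dev);   // quantized weights (layout assumed)
    auto b_s = Tensor::empty({1, N}, DataType::F32, dev);  // per-column weight scales (dtype assumed)
    auto c   = Tensor::empty({M, N}, DataType::BF16, dev); // dequantized result
    op::scaled_mm_i8_(c, a_p, a_s, b_p, b_s, /*bias=*/std::nullopt);
}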
#include "../../utils.hpp"
#include "../infiniop_impl.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/scaled_mm_i8.hpp"
#include <infiniop.h>
namespace infinicore::op::scaled_mm_i8_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, I8Gemm, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, c, a_p, a_s, b_p, b_s;
std::optional<graph::GraphTensor> bias;
};
void *plan(Tensor c, const Tensor &a_p, const Tensor &a_s, const Tensor &b_p, const Tensor &b_s, std::optional<Tensor> bias) {
size_t seed = hash_combine(c, a_p, a_s, b_p, b_s);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, I8Gemm,
seed,
c->desc(), bias.has_value() ? bias.value()->desc() : nullptr,
a_p->desc(), a_s->desc(), b_p->desc(), b_s->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, I8Gemm, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(c),
graph::GraphTensor(a_p),
graph::GraphTensor(a_s),
graph::GraphTensor(b_p),
graph::GraphTensor(b_s),
// bias.has_value() ? bias.value()->desc() : nullptr};
bias ? std::optional<graph::GraphTensor>(graph::GraphTensor(*bias)) : std::nullopt};
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopI8Gemm(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->c->data(),
// planned->bias->data(),
planned->bias.has_value() ? planned->bias.value()->data() : nullptr,
planned->a_p->data(),
planned->a_s->data(),
planned->b_p->data(),
planned->b_s->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(I8Gemm, &plan, &run, &cleanup);
} // namespace infinicore::op::scaled_mm_i8_impl::infiniop
#include "infinicore/ops/silu_and_mul.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SiluAndMul);
SiluAndMul::SiluAndMul(Tensor out, const Tensor &x) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, x);
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, x);
}
void SiluAndMul::execute(Tensor out, const Tensor &x) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(SiluAndMul, out, x);
}
Tensor silu_and_mul(const Tensor &x) {
Shape shape = x->shape();
size_t ndim = x->ndim();
if (shape[ndim - 1] % 2 != 0) {
throw std::runtime_error("SiluAndMul input last dim must be even.");
}
shape[ndim - 1] /= 2;
auto out = Tensor::empty(shape, x->dtype(), x->device());
silu_and_mul_(out, x);
return out;
}
void silu_and_mul_(Tensor out, const Tensor &x) {
SiluAndMul::execute(out, x);
}
} // namespace infinicore::op
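To make the shape contract of the out-of-place wrapper concrete, a hedged sketch (the device handle `dev` and the dtype are illustrative; the halving of the last dimension is what the code above enforces): the last dimension holds the gate and up projections concatenated, and every other dimension is preserved.
auto x = Tensor::empty({8, 2 * 4096}, DataType::BF16, dev); // gate and up halves concatenated on the last dim
Tensor y = op::silu_and_mul(x);                             // y->shape() == {8, 4096}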
#include "../infiniop_impl.hpp"
#include "infinicore/ops/silu_and_mul.hpp"
namespace infinicore::op::silu_and_mul_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SiluAndMul, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, output, input;
};
void *plan(Tensor output, const Tensor &input) {
size_t seed = hash_combine(output, input);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, SiluAndMul,
seed, output->desc(), input->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, SiluAndMul, descriptor);
auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(output),
graph::GraphTensor(input)};
return planned;
}
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(infiniopSiluAndMul(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->output->data(),
planned->input->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SiluAndMul, &plan, &run, &cleanup);
} // namespace infinicore::op::silu_and_mul_impl::infiniop
#include "infinicore/ops/swiglu.hpp"
#include "../../utils.hpp"
#include <stdexcept>
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SwiGLU);
SwiGLU::SwiGLU(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b);
INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b);
}
void SwiGLU::execute(Tensor c, const Tensor &a, const Tensor &b) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(SwiGLU, c, a, b);
}
Tensor swiglu(const Tensor &a, const Tensor &b) {
auto c = Tensor::empty(a->shape(), a->dtype(), a->device());
swiglu_(c, a, b);
return c;
}
void swiglu_(Tensor c, const Tensor &a, const Tensor &b) {
SwiGLU::execute(c, a, b);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/swiglu.hpp"
#include <infiniop.h>
#include "../infiniop_impl.hpp"
namespace infinicore::op::swiglu_impl::infiniop {
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SwiGLU, 100);
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace;
graph::GraphTensor c;
graph::GraphTensor a;
graph::GraphTensor b;
};
void *plan(Tensor c, const Tensor &a, const Tensor &b) {
size_t key = hash_combine(c, a, b);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, SwiGLU,
key, c->desc(), a->desc(), b->desc());
INFINIOP_WORKSPACE_TENSOR(workspace, SwiGLU, descriptor);
return new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(c),
graph::GraphTensor(a),
graph::GraphTensor(b)};
}
void run(void *planned_meta) {
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
INFINICORE_CHECK_ERROR(
infiniopSwiGLU(
p->descriptor->desc,
p->workspace->data(),
p->workspace->numel(),
p->c->data(),
p->a->data(),
p->b->data(),
context::getStream()));
}
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SwiGLU, &plan, &run, &cleanup);
} // namespace infinicore::op::swiglu_impl::infiniop
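For reference, the plan/run/cleanup triple registered above is the graph-op lifecycle; the same functions can also be exercised directly. A minimal sketch, assuming c, a and b are already-allocated device tensors of matching shape and dtype:
void *meta = plan(c, a, b); // creates or reuses the cached SwiGLU descriptor and workspace
run(meta);                  // launches infiniopSwiGLU on the current stream
cleanup(&meta);             // frees the PlannedMeta and nulls the pointer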
......@@ -22,6 +22,7 @@ inline void bind(py::module &m) {
.value("QY", Device::Type::QY)
.value("KUNLUN", Device::Type::KUNLUN)
.value("HYGON", Device::Type::HYGON)
.value("ALI", Device::Type::ALI)
.value("COUNT", Device::Type::COUNT);
device
......
......@@ -7,7 +7,10 @@
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/embedding.hpp"
#include "ops/flash_attention.hpp"
#include "ops/kv_caching.hpp"
#include "ops/linear.hpp"
#include "ops/linear_w8a8i8.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/paged_attention.hpp"
......@@ -18,6 +21,7 @@
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/silu_and_mul.hpp"
#include "ops/swiglu.hpp"
namespace py = pybind11;
......@@ -29,19 +33,23 @@ inline void bind(py::module &m) {
bind_add_rms_norm(m);
bind_attention(m);
bind_causal_softmax(m);
bind_flash_attention(m);
bind_kv_caching(m);
bind_linear(m);
bind_matmul(m);
bind_mul(m);
bind_paged_attention(m);
bind_paged_attention_prefill(m);
bind_paged_caching(m);
bind_random_sample(m);
bind_rearrange(m);
bind_rms_norm(m);
bind_silu(m);
bind_swiglu(m);
bind_rope(m);
bind_embedding(m);
bind_linear_w8a8i8(m);
bind_silu_and_mul(m);
}
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/flash_attention.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_flash_attention(py::module &m) {
m.def("flash_attention",
&op::flash_attention,
py::arg("q"),
py::arg("k"),
py::arg("v"),
py::arg("total_kv_len"),
py::arg("scale"),
py::arg("is_causal"));
}
} // namespace infinicore::ops
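As a side note on the scale argument exposed above: it is typically set to 1/sqrt(head_dim). A hedged helper (head_dim is a placeholder, not a name from this binding):
#include <cmath>
inline float attention_scale(int head_dim) {
    return 1.0f / std::sqrt(static_cast<float>(head_dim)); // conventional attention scaling
}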
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/kv_caching.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_kv_caching(py::module &m) {
m.def("kv_caching_",
&op::kv_caching_,
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("k"),
py::arg("v"),
py::arg("past_kv_lengths"),
R"doc(In-place Key-Value Caching.
Updates the KV cache in-place with new key and value tensors.
Args:
k_cache: Key cache tensor to update in-place
v_cache: Value cache tensor to update in-place
k: New key tensor to append
v: New value tensor to append
past_kv_lengths: Tensor containing current sequence lengths for each batch
)doc");
}
} // namespace infinicore::ops
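To illustrate the docstring above, a CPU reference sketch of the append semantics (editor's illustration). A contiguous [batch, max_len, hidden] layout and float storage are assumptions; the real kernel's layout and dtype are not visible in this diff.
#include <cstddef>
#include <cstdint>
void kv_caching_ref(float *k_cache, float *v_cache,
                    const float *k, const float *v,
                    const int64_t *past_kv_lengths,
                    size_t batch, size_t max_len, size_t new_len, size_t hidden) {
    for (size_t s = 0; s < batch; ++s) {
        size_t pos = static_cast<size_t>(past_kv_lengths[s]); // cached prefix length of sequence s
        for (size_t t = 0; t < new_len; ++t) {
            for (size_t h = 0; h < hidden; ++h) {
                // append the new keys/values right after the cached prefix
                k_cache[(s * max_len + pos + t) * hidden + h] = k[(s * new_len + t) * hidden + h];
                v_cache[(s * max_len + pos + t) * hidden + h] = v[(s * new_len + t) * hidden + h];
            }
        }
    }
}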
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/linear_w8a8i8.hpp"
namespace py = pybind11;
namespace infinicore::ops {
Tensor py_linear_w8a8i8(Tensor input,
Tensor weight_packed,
Tensor weight_scale,
pybind11::object bias) {
std::optional<Tensor> bias_tensor = std::nullopt;
if (!bias.is_none()) {
bias_tensor = bias.cast<Tensor>();
}
return op::linear_w8a8i8(input, weight_packed, weight_scale, bias_tensor);
}
void py_linear_w8a8i8_(Tensor out,
Tensor input,
Tensor weight_packed,
Tensor weight_scale,
pybind11::object bias) {
std::optional<Tensor> bias_tensor = std::nullopt;
if (!bias.is_none()) {
bias_tensor = bias.cast<Tensor>();
}
op::linear_w8a8i8_(out, input, weight_packed, weight_scale, bias_tensor);
}
inline void bind_linear_w8a8i8(py::module &m) {
m.def("linear_w8a8i8",
&ops::py_linear_w8a8i8,
py::arg("input"),
py::arg("weight_packed"),
py::arg("weight_scale"),
py::arg("bias") = py::none(),
R"doc(linear_w8a8i8.)doc");
m.def("linear_w8a8i8_",
&ops::py_linear_w8a8i8_,
py::arg("out"),
py::arg("input"),
py::arg("weight_packed"),
py::arg("weight_scale"),
py::arg("bias") = py::none(),
R"doc(linear_w8a8i8_.)doc");
}
} // namespace infinicore::ops
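The W8A8 path bound above composes per-channel activation quantization, an int8 GEMM, and dequantization by the two scale vectors. A scalar reference for one output element, as a sketch of the usual formulation (not necessarily this kernel's exact epilogue):
#include <cstdint>
// y[m][n] ~= (sum_k aq[m][k] * wq[n][k]) * a_scale[m] * w_scale[n] + bias[n]
float w8a8_dequant_ref(const int8_t *aq, const int8_t *wq, int K,
                       float a_scale, float w_scale, float bias) {
    int32_t acc = 0; // int32 accumulator avoids int8 overflow
    for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(aq[k]) * static_cast<int32_t>(wq[k]);
    }
    return static_cast<float>(acc) * a_scale * w_scale + bias;
}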
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/per_channel_quant_i8.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_per_channel_quant_i8(py::module &m) {
m.def("per_channel_quant_i8_",
&op::per_channel_quant_i8_,
py::arg("x"),
py::arg("x_packed"),
py::arg("x_scale"),
R"doc(Per-channel quantization of a tensor.)doc");
}
} // namespace infinicore::ops
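per_channel_quant_i8_ writes an int8 payload plus one scale per channel. A sketch of the usual symmetric scheme for a single row (the kernel's exact rounding and clamping are not shown in this diff and are assumed here):
#include <algorithm>
#include <cmath>
#include <cstdint>
void quant_row_i8_ref(const float *x, int n, int8_t *x_packed, float *x_scale) {
    float amax = 0.f;
    for (int i = 0; i < n; ++i) amax = std::max(amax, std::fabs(x[i]));
    float scale = amax > 0.f ? amax / 127.f : 1.f; // map [-amax, amax] onto [-127, 127]
    for (int i = 0; i < n; ++i) {
        x_packed[i] = static_cast<int8_t>(std::lrintf(x[i] / scale));
    }
    *x_scale = scale;
}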
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/scaled_mm_i8.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_scaled_mm_i8(py::module &m) {
m.def("scaled_mm_i8",
&op::scaled_mm_i8,
py::arg("a_p"),
py::arg("a_s"),
py::arg("b_p"),
py::arg("b_s"),
py::arg("bias"),
R"doc(Scaled matrix multiplication of two tensors.)doc");
m.def("scaled_mm_i8_",
&op::scaled_mm_i8_,
py::arg("a"),
py::arg("b"),
py::arg("a_scale"),
py::arg("b_scale"),
R"doc(In-place Scaled matrix multiplication of two tensors.)doc");
}
} // namespace infinicore::ops
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/silu_and_mul.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_silu_and_mul(py::module &m) {
m.def("silu_and_mul",
&op::silu_and_mul,
py::arg("input"),
R"doc(
SiLU and Mul (SwiGLU) activation function.
Input should be [..., 2*d], output will be [..., d].
)doc");
m.def("silu_and_mul_",
&op::silu_and_mul_,
py::arg("output"),
py::arg("input"),
R"doc(
In-place or destination-specified SiLU and Mul (SwiGLU) activation function.
)doc");
}
} // namespace infinicore::ops
......@@ -95,6 +95,20 @@ void print_data_bf16(const uint16_t *data, const Shape &shape, const Strides &st
}
}
// Function for printing I8 data
void print_data_i8(const int8_t *data, const Shape &shape, const Strides &strides, size_t dim) {
if (dim == shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
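// Cast to int so int8 values print as numbers rather than ASCII characters.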
std::cout << static_cast<int>(data[i * strides[dim]]) << " ";
}
std::cout << std::endl;
} else if (dim < shape.size() - 1) {
for (size_t i = 0; i < shape[dim]; i++) {
print_data_i8(data + i * strides[dim], shape, strides, dim + 1);
}
}
}
// Template function for writing data recursively to binary file (handles non-contiguous tensors)
template <typename T>
void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, const Strides &strides, size_t dim) {
......@@ -181,8 +195,8 @@ void TensorImpl::debug(const std::string &filename) const {
cpu_tensor->shape(), cpu_tensor->strides(), 0);
break;
case DataType::I8:
print_data_i8(reinterpret_cast<const int8_t *>(cpu_data),
cpu_tensor->shape(), cpu_tensor->strides(), 0);
break;
case DataType::BF16:
print_data_bf16(reinterpret_cast<const uint16_t *>(cpu_data),
......
......@@ -2,6 +2,8 @@
#include "infinicore/dtype.hpp"
#include "infinicore/tensor.hpp"
#include "../utils.hpp"
#include <spdlog/spdlog.h>
#include <stdexcept>
......@@ -62,11 +64,11 @@ Tensor TensorImpl::narrow(const std::vector<TensorSliceParams> &slices) const {
Tensor TensorImpl::permute(const Shape &order) const {
// Validate input
INFINICORE_ASSERT(meta_.shape.size() == order.size());
// Check that order contains all indices from 0 to n-1 exactly once
for (size_t i = 0; i < order.size(); i++) {
INFINICORE_ASSERT(std::find(order.begin(), order.end(), i) != order.end());
}
// Permute shape and strides
......
......@@ -22,7 +22,7 @@ void printUsage() {
std::cout << " Path to the test gguf file" << std::endl
<< std::endl;
std::cout << " --<device>[:id]" << std::endl;
std::cout << " (Optional) Specify the device type --(cpu|nvidia|cambricon|ascend|metax|moore|iluvatar|qy|kunlun|hygon) and device ID (optional). CPU by default." << std::endl
std::cout << " (Optional) Specify the device type --(cpu|nvidia|cambricon|ascend|metax|moore|iluvatar|qy|kunlun|hygon|ali) and device ID (optional). CPU by default." << std::endl
<< std::endl;
std::cout << " --warmup <warmups>" << std::endl;
std::cout << " (Optional) Number of warmups to perform before timing. Default to 0." << std::endl
......@@ -80,6 +80,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
PARSE_DEVICE("--qy", INFINI_DEVICE_QY)
PARSE_DEVICE("--kunlun", INFINI_DEVICE_KUNLUN)
PARSE_DEVICE("--hygon", INFINI_DEVICE_HYGON)
PARSE_DEVICE("--ali", INFINI_DEVICE_ALI)
else if (arg == "--warmup" && i + 1 < argc) {
args.warmups = std::stoi(argv[++i]);
}
......
......@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/cpu_handle.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/nvidia_handle.h"
#endif
#ifdef ENABLE_CAMBRICON_API
......@@ -47,6 +47,9 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
#ifdef ENABLE_ILUVATAR_API
CREATE(INFINI_DEVICE_ILUVATAR, iluvatar);
#endif
#ifdef ENABLE_ALI_API
CREATE(INFINI_DEVICE_ALI, ali);
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, qy);
#endif
......@@ -93,6 +96,9 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
#ifdef ENABLE_ILUVATAR_API
DELETE(INFINI_DEVICE_ILUVATAR, iluvatar);
#endif
#ifdef ENABLE_ALI_API
DELETE(INFINI_DEVICE_ALI, ali);
#endif
#ifdef ENABLE_QY_API
DELETE(INFINI_DEVICE_QY, qy);
#endif
......
......@@ -85,4 +85,20 @@
#define hcclSuccess mcclSuccess
#define hcclCommDestroy mcclCommDestroy
#define hcclAllReduce mcclAllReduce
#define hcGetDevice mcGetDevice
#define hcDeviceAttributeMultiProcessorCount mcDeviceAttributeMultiProcessorCount
#define hcDeviceGetAttribute mcDeviceGetAttribute
#define hcStreamCaptureMode mcStreamCaptureMode
#define hcStreamCaptureModeGlobal mcStreamCaptureModeGlobal
#define hcStreamCaptureModeThreadLocal mcStreamCaptureModeThreadLocal
#define hcStreamCaptureModeRelaxed mcStreamCaptureModeRelaxed
#define hcStreamBeginCapture mcStreamBeginCapture
#define hcStreamEndCapture mcStreamEndCapture
#define hcGraph_t mcGraph_t
#define hcGraphExec_t mcGraphExec_t
#define hcGraphNode_t mcGraphNode_t
#define hcGraphInstantiate mcGraphInstantiate
#define hcGraphDestroy mcGraphDestroy
#define hcGraphExecDestroy mcGraphExecDestroy
#define hcGraphLaunch mcGraphLaunch
#endif
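The aliases above let the shared CUDA-graph capture path compile against the MetaX "mc*" runtime. An illustrative capture/replay flow using them; hcStream_t and the 5-argument hcGraphInstantiate signature are assumed to mirror their CUDA counterparts and are not confirmed by this diff.
static void capture_and_replay(hcStream_t stream) {
    hcGraph_t graph = nullptr;
    hcGraphExec_t graph_exec = nullptr;
    CHECK_METAX(hcStreamBeginCapture(stream, hcStreamCaptureModeThreadLocal));
    // ... enqueue the kernels that the graph should replay ...
    CHECK_METAX(hcStreamEndCapture(stream, &graph));
    CHECK_METAX(hcGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
    CHECK_METAX(hcGraphLaunch(graph_exec, stream));
    CHECK_METAX(hcGraphExecDestroy(graph_exec));
    CHECK_METAX(hcGraphDestroy(graph));
}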
......@@ -8,8 +8,10 @@
// Possible maximum number of threads per block for METAX architectures
// Used for picking correct kernel launch configuration
#define METAX_BLOCK_SIZE_512 512
#define METAX_BLOCK_SIZE_1024 1024
#define METAX_BLOCK_SIZE_2048 2048
#define METAX_BLOCK_SIZE_4096 4096
#define CHECK_METAX(API) CHECK_INTERNAL(API, hcSuccess)
......@@ -17,6 +19,12 @@ using cuda_bfloat16 = hpcc_bfloat16;
using cuda_bfloat162 = hpcc_bfloat162;
using cuda_fp8_e4m3 = __hpcc_fp8_e4m3;
#ifdef ENABLE_METAX_MC_API
using __nv_bfloat16 = __maca_bfloat16;
#else
using __nv_bfloat16 = __hpcc_bfloat16;
#endif
namespace device::metax {
// get the memory offset of the given element in a tensor given its flat index
......