Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out, void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t *__restrict__ input, const scalar_t* __restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon, const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) { const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
...@@ -41,9 +41,9 @@ void rms_norm_impl(scalar_t *__restrict__ out, ...@@ -41,9 +41,9 @@ void rms_norm_impl(scalar_t *__restrict__ out,
} }
template <typename scalar_t> template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input, void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t *__restrict__ residual, scalar_t* __restrict__ residual,
const scalar_t *__restrict__ weight, const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens, const float epsilon, const int num_tokens,
const int hidden_size) { const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
...@@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input, ...@@ -87,8 +87,8 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
...@@ -101,8 +101,8 @@ void rms_norm(torch::Tensor &out, torch::Tensor &input, ...@@ -101,8 +101,8 @@ void rms_norm(torch::Tensor &out, torch::Tensor &input,
}); });
} }
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor &weight, float epsilon) { torch::Tensor& weight, float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
......
...@@ -4,22 +4,21 @@ ...@@ -4,22 +4,21 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
...@@ -27,7 +26,7 @@ void rotary_embedding_impl( ...@@ -27,7 +26,7 @@ void rotary_embedding_impl(
#pragma omp parallel for #pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
const int head_idx = i; const int head_idx = i;
...@@ -95,16 +94,16 @@ void rotary_embedding_impl( ...@@ -95,16 +94,16 @@ void rotary_embedding_impl(
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
...@@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl( ...@@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query; scalar_t* head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
...@@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl( ...@@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head; scalar_t* head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
...@@ -167,9 +166,9 @@ void rotary_embedding_gptj_impl( ...@@ -167,9 +166,9 @@ void rotary_embedding_gptj_impl(
} }
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor &key, int head_size, torch::Tensor& key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
......
...@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation."); "Activation function used in GeGLU with `none` approximation.");
ops.def( ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation."); "Activation function used in GeGLU with `tanh` approximation.");
ops.def( ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
"gelu_new", ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor."); "Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization"); "In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst"); "Swap in (out) the cache blocks from src to dst");
cache_ops.def( cache_ops.def("copy_blocks", &copy_blocks,
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them"); "Reshape the key and value tensors and cache them");
} }
#pragma once #pragma once
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif #endif
...@@ -28,6 +29,13 @@ ...@@ -28,6 +29,13 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif #endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
...@@ -35,4 +43,3 @@ ...@@ -35,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif
...@@ -2,9 +2,6 @@ ...@@ -2,9 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
int get_device_attribute( int get_device_attribute(int attribute, int device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute( int get_max_shared_memory_per_block_device_attribute(int device_id);
int device_id);
...@@ -2,28 +2,22 @@ ...@@ -2,28 +2,22 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute,
int device_id)
{
int device, value; int device, value;
if (device_id < 0) { if (device_id < 0) {
cudaGetDevice(&device); cudaGetDevice(&device);
} } else {
else {
device = device_id; device = device_id;
} }
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
device);
return value; return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute( int attribute;
int device_id) // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM #ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -12,8 +12,7 @@ ...@@ -12,8 +12,7 @@
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
...@@ -22,8 +21,8 @@ ...@@ -22,8 +21,8 @@
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
...@@ -33,5 +32,4 @@ ...@@ -33,5 +32,4 @@
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
This diff is collapsed.
...@@ -3,5 +3,6 @@ ...@@ -3,5 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
} }
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output); torch::Tensor& gating_output);
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \ f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4096) \
...@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7168) \
...@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32256) \
...@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \ f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \
...@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \ f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \
...@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ...@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 22016, narrow) \ f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment