Unverified commit dc8ddd58, authored by Ziminli, committed by GitHub

issue/388: Support 3D Cases for RMS Norm

parent 47895fae
@@ -21,19 +21,27 @@ infiniStatus_t Descriptor::create(
 template <typename T>
 infiniStatus_t rmsnorm(const RMSNormInfo *info, T *y, const T *x, const T *w) {
+    const size_t batch_size = info->shape[0];
+    const size_t nhead = info->shape.size() > 2 ? info->shape[1] : 1;
+    const size_t dim = info->shape.back();
+    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
 #pragma omp parallel for
-    for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) {
-        T *x_ = (T *)(x + i * info->x_strides[0]);
-        T *y_ = (T *)(y + i * info->y_strides[0]);
+    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
+        const size_t i = block_idx / nhead; // batch index
+        const size_t j = block_idx % nhead; // head index
+        const T *x_ptr = x + i * info->x_strides[0] + j * info->x_strides[1];
+        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
         // [Reduce] sum of x^2 on last dimension
-        T ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]);
+        T ss = op::common_cpu::reduce_op::sumSquared(x_ptr, dim, info->x_strides.back());
         // 1 / (sqrt(sum/dim + eps))
-        T rms = (T)1 / std::sqrt(ss / (T)(info->shape[1]) + (T)(info->epsilon));
+        T rms = (T)1 / std::sqrt(ss / (T)(dim) + (T)(info->epsilon));
-        for (size_t j = 0; j < info->shape[1]; j++) {
-            y_[j * info->y_strides[1]] = x_[j * info->x_strides[1]] * w[j] * rms;
+        for (size_t k = 0; k < dim; k++) {
+            y_ptr[k] = x_ptr[k] * w[k] * rms;
         }
     }
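The change above collapses the (batch, head) loops into a single OpenMP-parallel loop and recovers both indices with division and modulo; for 2D inputs nhead is 1, so the head term vanishes and the old behavior is preserved. The same decomposition is reused by the GPU kernels further down. A minimal standalone sketch of the scheme, with illustrative names and contiguous row-major strides rather than the library's types:

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t batch = 3, nhead = 4, dim = 8;
        // Contiguous row-major strides for a (batch, nhead, dim) tensor
        const ptrdiff_t stride_batch = nhead * dim, stride_nhead = dim;
        std::vector<float> x(batch * nhead * dim, 1.0f), y(x.size());
        const std::vector<float> w(dim, 0.5f);
        const float eps = 1e-5f;

        const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch * nhead);
        for (ptrdiff_t b = 0; b < total_blocks; ++b) {
            const size_t i = b / nhead; // batch index
            const size_t j = b % nhead; // head index
            const float *xp = x.data() + i * stride_batch + j * stride_nhead;
            float *yp = y.data() + i * stride_batch + j * stride_nhead;
            float ss = 0.f;
            for (size_t k = 0; k < dim; ++k) ss += xp[k] * xp[k]; // sum of x^2
            const float rms = 1.f / std::sqrt(ss / dim + eps);
            for (size_t k = 0; k < dim; ++k) yp[k] = xp[k] * w[k] * rms;
        }
        std::printf("y[0] = %f\n", y[0]); // all-ones input: 0.5 / sqrt(1 + eps)
    }

Flattening keeps a single parallel loop regardless of rank, so the 2D path needs no special casing.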
@@ -45,24 +53,32 @@ infiniStatus_t rmsnormHalfPrecision(const RMSNormInfo *info, T *y, const T *x, c
     static_assert(std::is_same<T, fp16_t>::value || std::is_same<T, bf16_t>::value,
                   "T must be fp16_t or bf16_t");
+    const size_t batch_size = info->shape[0];
+    const size_t nhead = info->shape.size() > 2 ? info->shape[1] : 1;
+    const size_t dim = info->shape.back();
+    const ptrdiff_t total_blocks = static_cast<ptrdiff_t>(batch_size * nhead);
 #pragma omp parallel for
-    for (ptrdiff_t i = 0; i < ptrdiff_t(info->shape[0]); i++) {
-        T *x_ = (T *)(x + i * info->x_strides[0]);
-        T *y_ = (T *)(y + i * info->y_strides[0]);
+    for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) {
+        const size_t i = block_idx / nhead; // batch index
+        const size_t j = block_idx % nhead; // head index
+        const T *x_ptr = x + i * info->x_strides[0] + j * info->x_strides[1];
+        T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1];
         // [Reduce] sum of x^2 on last dimension
-        float ss = op::common_cpu::reduce_op::sumSquared(x_, info->shape[1], info->x_strides[1]);
+        float ss = op::common_cpu::reduce_op::sumSquared(x_ptr, dim, info->x_strides.back());
         // 1 / (sqrt(sum/dim + eps))
-        float rms = 1.f / std::sqrt(ss / (float)(info->shape[1]) + info->epsilon);
+        float rms = 1.f / std::sqrt(ss / (float)(dim) + info->epsilon);
-        for (size_t j = 0; j < info->shape[1]; j++) {
+        for (size_t k = 0; k < dim; k++) {
             if constexpr (std::is_same<Tw, float>::value) {
-                float val = utils::cast<float>(x_[j * info->x_strides[1]]) * w[j] * rms;
-                y_[j * info->y_strides[1]] = utils::cast<T>(val);
+                float val = utils::cast<float>(x_ptr[k]) * w[k] * rms;
+                y_ptr[k] = utils::cast<T>(val);
             } else if constexpr (std::is_same<Tw, T>::value) {
-                float val = utils::cast<float>(x_[j * info->x_strides[1]]) * utils::cast<float>(w[j]) * rms;
-                y_[j * info->y_strides[1]] = utils::cast<T>(val);
+                float val = utils::cast<float>(x_ptr[k]) * utils::cast<float>(w[k]) * rms;
+                y_ptr[k] = utils::cast<T>(val);
             } else {
                 std::abort();
             }
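As before the change, the half-precision variant reduces and normalizes in float and only casts back to T at the end. The reason is range: fp16's largest finite value is 65504, so even modest activations overflow an fp16 sum-of-squares accumulator, and bf16 has the range but not the mantissa for a long accumulation. A small numeric illustration of that rationale, with made-up values:

    #include <cstdio>

    int main() {
        const float fp16_max = 65504.f;
        const int dim = 2048;
        const float v = 8.f;                       // a modest activation value
        float ss = 0.f;
        for (int k = 0; k < dim; ++k) ss += v * v; // float accumulator: fine
        std::printf("sum of squares = %.0f (fp16 max is %.0f)\n", ss, fp16_max);
        // 2048 * 64 = 131072 > 65504: an fp16 accumulator would overflow to inf,
        // so the reduction and the 1/sqrt are done in float before casting back.
    }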
@@ -93,9 +109,9 @@ infiniStatus_t Descriptor::calculate(
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
     } else if (_info.atype == INFINI_DTYPE_F32) {
-        CHECK_STATUS(rmsnorm(&_info, (float *)y, (float *)x, (float *)w));
+        CHECK_STATUS(rmsnorm(&_info, (float *)y, (const float *)x, (const float *)w));
     } else if (_info.atype == INFINI_DTYPE_F64) {
-        CHECK_STATUS(rmsnorm(&_info, (double *)y, (double *)x, (double *)w));
+        CHECK_STATUS(rmsnorm(&_info, (double *)y, (const double *)x, (const double *)w));
     } else {
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
......
@@ -4,16 +4,22 @@
 template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
 __device__ void rmsnormBlock(
     Tdata *__restrict__ y,
-    ptrdiff_t stride_y,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
     const Tdata *__restrict__ x,
-    ptrdiff_t stride_x,
+    ptrdiff_t stride_x_batch,
+    ptrdiff_t stride_x_nhead,
     const Tweight *__restrict__ w,
+    size_t nhead,
     size_t dim,
     float epsilon) {
-    // Each block takes care of a row of continuous data of length dim
+    // Each block takes care of one head in one batch
     // Each thread deals with every block_size element in the row
-    auto y_ptr = y + blockIdx.x * stride_y;
-    auto x_ptr = x + blockIdx.x * stride_x;
+    size_t batch_idx = blockIdx.x / nhead;
+    size_t head_idx = blockIdx.x % nhead;
+    auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead;
+    auto x_ptr = x + batch_idx * stride_x_batch + head_idx * stride_x_nhead;
    auto w_ptr = w;

    // Block-reduce sum of x^2
......
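rmsnormBlock is the device function shared by the Nvidia, Metax, and Moore backends below; each block now owns one (batch, head) row, and, per the comment above, each thread accumulates every BLOCK_SIZE-th element of the row before a block-wide reduction. The reduction body is collapsed in this view; a host-side emulation of the general pattern, assuming BLOCK_SIZE is a power of two (a sketch of the idea, not the kernel's actual reduction code):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // One "block" per row, block_size "threads" each accumulating every
    // block_size-th element, then a tree reduction over the partial sums.
    float blockSumSquared(const float *row, size_t dim, unsigned block_size) {
        std::vector<float> partial(block_size, 0.f);
        for (unsigned t = 0; t < block_size; ++t)         // "threads"
            for (size_t k = t; k < dim; k += block_size)  // strided walk of the row
                partial[t] += row[k] * row[k];
        for (unsigned s = block_size / 2; s > 0; s >>= 1) // tree reduction
            for (unsigned t = 0; t < s; ++t)
                partial[t] += partial[t + s];
        return partial[0];
    }

    int main() {
        std::vector<float> row(1000, 1.f);
        // 1000 ones: sum of squares is 1000, whatever block size is used
        std::printf("%.0f\n", blockSumSquared(row.data(), row.size(), 256));
    }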
@@ -46,21 +46,39 @@ public:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
-        if (y_desc->ndim() != 2 || x_desc->ndim() != 2 || w_desc->ndim() != 1) {
-            return INFINI_STATUS_BAD_TENSOR_SHAPE;
-        }
+        const size_t y_ndim = y_desc->ndim();
+        const size_t x_ndim = x_desc->ndim();
+        const size_t w_ndim = w_desc->ndim();
-        size_t batch = y_desc->shape()[0];
-        size_t dim = y_desc->shape()[1];
-        if (x_desc->shape()[0] != batch || x_desc->shape()[1] != dim || w_desc->shape()[0] != dim) {
+        if (y_ndim != x_ndim || w_ndim != 1) {
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
-        if (w_desc->stride(0) != 1) {
-            return INFINI_STATUS_BAD_TENSOR_STRIDES;
-        }
+        size_t batch = 1;
+        size_t nhead = 1;
+        size_t dim = 0;
+        if (y_ndim == 2) {
+            batch = y_desc->dim(0);
+            dim = y_desc->dim(1);
+            if (x_desc->dim(0) != batch || x_desc->dim(1) != dim || w_desc->dim(0) != dim) {
+                return INFINI_STATUS_BAD_TENSOR_SHAPE;
+            }
+        } else if (y_ndim == 3) {
+            batch = y_desc->dim(0);
+            nhead = y_desc->dim(1);
+            dim = y_desc->dim(2);
+            if (x_desc->dim(0) != batch || x_desc->dim(1) != nhead || x_desc->dim(2) != dim || w_desc->dim(0) != dim) {
+                return INFINI_STATUS_BAD_TENSOR_SHAPE;
+            }
+        } else {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
-        if (x_desc->stride(1) != 1 || y_desc->stride(1) != 1) {
+        // Check contiguity of the last dimension
+        if (y_desc->stride(y_ndim - 1) != 1 || x_desc->stride(x_ndim - 1) != 1 || w_desc->stride(w_ndim - 1) != 1) {
             return INFINI_STATUS_BAD_TENSOR_STRIDES;
         }
......
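The descriptor now accepts both 2D (batch, dim) and 3D (batch, nhead, dim) tensors and relaxes the old stride checks: only the innermost dimension must be contiguous, so padded outer strides pass validation. A hypothetical standalone version of that stride rule, with illustrative values:

    #include <cstddef>
    #include <vector>

    // The relaxed rule: only the innermost dimension must be contiguous;
    // batch and head strides may include padding.
    bool lastDimContiguous(const std::vector<ptrdiff_t> &strides) {
        return !strides.empty() && strides.back() == 1;
    }

    int main() {
        // Shape (3, 4, 512) with strides (4096, 1024, 1): rows of 512 elements
        // padded out to 1024, heads padded out to 4096: accepted, because each
        // row's 512 elements are still adjacent in memory.
        std::vector<ptrdiff_t> padded = {4096, 1024, 1};  // accepted
        std::vector<ptrdiff_t> transposed = {1, 12, 48};  // rejected
        return lastDimContiguous(padded) && !lastDimContiguous(transposed) ? 0 : 1;
    }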
@@ -11,13 +11,16 @@
 template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
 INFINIOP_METAX_KERNEL rmsnormKernel(
     Tdata *__restrict__ y,
-    ptrdiff_t stride_y,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
     const Tdata *__restrict__ x,
-    ptrdiff_t stride_x,
+    ptrdiff_t stride_x_batch,
+    ptrdiff_t stride_x_nhead,
     const Tweight *__restrict__ w,
+    size_t nhead,
     size_t dim,
     float epsilon) {
-    rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
+    rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead, dim, epsilon);
 }

 namespace op::rms_norm::metax {
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
     CHECK_RESULT(result);
     auto info = result.take();

-    // only support contiguous last dimension
-    if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
-        return INFINI_STATUS_BAD_TENSOR_STRIDES;
-    }
-
     *desc_ptr = new Descriptor(
         new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
         std::move(info),
@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
 // launch kernel with different data types
 template <unsigned int BLOCK_SIZE>
 infiniStatus_t launchKernel(
-    uint32_t batch_size, size_t dim,
-    void *y, infiniDtype_t atype, ptrdiff_t stride_y,
-    const void *x, ptrdiff_t stride_x,
+    uint32_t batch_size, size_t nhead, size_t dim,
+    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
+    const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
     const void *w, infiniDtype_t wtype,
     float epsilon,
     hcStream_t stream) {

-#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                 \
-    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, stream>>>( \
-        reinterpret_cast<Tdata *>(y),                                                           \
-        stride_y,                                                                               \
-        reinterpret_cast<const Tdata *>(x),                                                     \
-        stride_x,                                                                               \
-        reinterpret_cast<const Tweight *>(w),                                                   \
-        dim,                                                                                    \
+#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                         \
+    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
+        reinterpret_cast<Tdata *>(y),                                                                   \
+        stride_y_batch,                                                                                 \
+        stride_y_nhead,                                                                                 \
+        reinterpret_cast<const Tdata *>(x),                                                             \
+        stride_x_batch,                                                                                 \
+        stride_x_nhead,                                                                                 \
+        reinterpret_cast<const Tweight *>(w),                                                           \
+        nhead,                                                                                          \
+        dim,                                                                                            \
         epsilon)

     if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
@@ -102,15 +103,18 @@ infiniStatus_t Descriptor::calculate(
         return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
     }

-    auto stride_x = _info.x_strides[0];
-    auto stride_y = _info.y_strides[0];
+    auto stride_x_batch = _info.x_strides[0];
+    auto stride_x_nhead = _info.x_strides[1];
+    auto stride_y_batch = _info.y_strides[0];
+    auto stride_y_nhead = _info.y_strides[1];
     auto dim = _info.dim();
     uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
     auto stream = reinterpret_cast<hcStream_t>(stream_);

     // launch kernel with different block sizes
     if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
-        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, stream));
+        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, stream));
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
......
@@ -11,13 +11,16 @@
 template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
 INFINIOP_MOORE_KERNEL rmsnormKernel(
     Tdata *__restrict__ y,
-    ptrdiff_t stride_y,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
     const Tdata *__restrict__ x,
-    ptrdiff_t stride_x,
+    ptrdiff_t stride_x_batch,
+    ptrdiff_t stride_x_nhead,
     const Tweight *__restrict__ w,
+    size_t nhead,
     size_t dim,
     float epsilon) {
-    rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
+    rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead, dim, epsilon);
 }

 namespace op::rms_norm::moore {
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
     CHECK_RESULT(result);
     auto info = result.take();

-    // only support contiguous last dimension
-    if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
-        return INFINI_STATUS_BAD_TENSOR_STRIDES;
-    }
-
     *desc_ptr = new Descriptor(
         new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()},
         std::move(info),
@@ -57,20 +55,23 @@ infiniStatus_t Descriptor::create(
 // launch kernel with different data types
 template <unsigned int BLOCK_SIZE>
 infiniStatus_t launchKernel(
-    uint32_t batch_size, size_t dim,
-    void *y, infiniDtype_t atype, ptrdiff_t stride_y,
-    const void *x, ptrdiff_t stride_x,
+    uint32_t batch_size, size_t nhead, size_t dim,
+    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
+    const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
     const void *w, infiniDtype_t wtype,
     float epsilon,
     musaStream_t musa_stream) {

 #define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                              \
-    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, musa_stream>>>(         \
+    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
         reinterpret_cast<Tdata *>(y),                                                                        \
-        stride_y,                                                                                            \
+        stride_y_batch,                                                                                      \
+        stride_y_nhead,                                                                                      \
         reinterpret_cast<const Tdata *>(x),                                                                  \
-        stride_x,                                                                                            \
+        stride_x_batch,                                                                                      \
+        stride_x_nhead,                                                                                      \
         reinterpret_cast<const Tweight *>(w),                                                                \
+        nhead,                                                                                               \
         dim,                                                                                                 \
         epsilon)
@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
         return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
     }

-    auto stride_x = _info.x_strides[0];
-    auto stride_y = _info.y_strides[0];
+    auto stride_x_batch = _info.x_strides[0];
+    auto stride_x_nhead = _info.x_strides[1];
+    auto stride_y_batch = _info.y_strides[0];
+    auto stride_y_nhead = _info.y_strides[1];
     auto dim = _info.dim();
     uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
     auto musa_stream = reinterpret_cast<musaStream_t>(stream);

     // launch kernel with different block sizes
     if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) {
-        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
+        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, musa_stream));
     } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) {
-        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
+        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, musa_stream));
     } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) {
-        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, musa_stream));
+        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, musa_stream));
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
......
@@ -11,13 +11,16 @@
 template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
 INFINIOP_CUDA_KERNEL rmsnormKernel(
     Tdata *__restrict__ y,
-    ptrdiff_t stride_y,
+    ptrdiff_t stride_y_batch,
+    ptrdiff_t stride_y_nhead,
     const Tdata *__restrict__ x,
-    ptrdiff_t stride_x,
+    ptrdiff_t stride_x_batch,
+    ptrdiff_t stride_x_nhead,
     const Tweight *__restrict__ w,
+    size_t nhead,
     size_t dim,
     float epsilon) {
-    rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y, x, stride_x, w, dim, epsilon);
+    rmsnormBlock<BLOCK_SIZE, Tcompute>(y, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, nhead, dim, epsilon);
 }

 namespace op::rms_norm::nvidia {
@@ -41,11 +44,6 @@ infiniStatus_t Descriptor::create(
     CHECK_RESULT(result);
     auto info = result.take();

-    // only support contiguous last dimension
-    if (info.x_strides[1] != 1 || info.y_strides[1] != 1) {
-        return INFINI_STATUS_BAD_TENSOR_STRIDES;
-    }
-
     *desc_ptr = new Descriptor(
         new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
         std::move(info),
@@ -57,21 +55,24 @@ infiniStatus_t Descriptor::create(
 // launch kernel with different data types
 template <unsigned int BLOCK_SIZE>
 infiniStatus_t launchKernel(
-    uint32_t batch_size, size_t dim,
-    void *y, infiniDtype_t atype, ptrdiff_t stride_y,
-    const void *x, ptrdiff_t stride_x,
+    uint32_t batch_size, size_t nhead, size_t dim,
+    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
+    const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead,
     const void *w, infiniDtype_t wtype,
     float epsilon,
     cudaStream_t cuda_stream) {

-#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                      \
-    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size, BLOCK_SIZE, 0, cuda_stream>>>( \
-        reinterpret_cast<Tdata *>(y),                                                                \
-        stride_y,                                                                                    \
-        reinterpret_cast<const Tdata *>(x),                                                          \
-        stride_x,                                                                                    \
-        reinterpret_cast<const Tweight *>(w),                                                        \
-        dim,                                                                                         \
+#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute)                                                              \
+    rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, cuda_stream>>>( \
+        reinterpret_cast<Tdata *>(y),                                                                        \
+        stride_y_batch,                                                                                      \
+        stride_y_nhead,                                                                                      \
+        reinterpret_cast<const Tdata *>(x),                                                                  \
+        stride_x_batch,                                                                                      \
+        stride_x_nhead,                                                                                      \
+        reinterpret_cast<const Tweight *>(w),                                                                \
+        nhead,                                                                                               \
+        dim,                                                                                                 \
         epsilon)

     if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
@@ -102,19 +103,22 @@ infiniStatus_t Descriptor::calculate(
         return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
     }

-    auto stride_x = _info.x_strides[0];
-    auto stride_y = _info.y_strides[0];
+    auto stride_x_batch = _info.x_strides[0];
+    auto stride_x_nhead = _info.x_strides[1];
+    auto stride_y_batch = _info.y_strides[0];
+    auto stride_y_nhead = _info.y_strides[1];
     auto dim = _info.dim();
     uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
+    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
     auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

     // launch kernel with different block sizes
     if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream));
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream));
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_512>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
-        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(batch_size, dim, y, _info.atype, stride_y, x, stride_x, w, _info.wtype, _info.epsilon, cuda_stream));
+        CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_4096>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
......
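The Metax, Moore, and Nvidia launchers above all size the grid the same way: one block per normalized row, batch_size * nhead blocks in total, which degenerates to the old batch_size-block launch when nhead is 1. The launch arithmetic in isolation, with illustrative values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t batch_size = 3;
        const size_t nhead = 4;                 // 1 for 2D inputs
        const unsigned block_size = 1024;       // picked from maxThreadsPerBlock()
        const size_t grid = batch_size * nhead; // one block per (batch, head) row
        std::printf("launch <<<%zu, %u>>>\n", grid, block_size);
    }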
@@ -97,6 +97,10 @@ if __name__ == "__main__":
         ((500, 4096), None, (8192, 1)),
         ((4, 512), (1024, 1), (512, 1)),
         ((4, 512), None, (2048, 1)),
+        ((3, 4, 512), None, None),
+        ((3, 4, 512), None, (4096, 1024, 1)),
+        ((3, 4, 512), (4096, 1024, 1), None),
+        ((3, 4, 512), (4096, 1024, 1), (4096, 1024, 1)),
     ]
     _TENSOR_DTYPES_ = [np.float32, np.float16]
     for dtype in _TENSOR_DTYPES_:
......
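In the new 3D cases, a stride tuple such as (4096, 1024, 1) for shape (3, 4, 512) describes a padded layout: each 512-element row starts 1024 elements apart and each batch 4096 elements apart. A sketch of the resulting element offsets:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const ptrdiff_t strides[3] = {4096, 1024, 1}; // from the (3, 4, 512) test case
        const ptrdiff_t idx[3] = {2, 3, 511};         // last element of the tensor
        ptrdiff_t offset = 0;
        for (int d = 0; d < 3; ++d) offset += idx[d] * strides[d];
        std::printf("flat offset = %td\n", offset); // 2*4096 + 3*1024 + 511 = 11775
    }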
@@ -25,11 +25,14 @@ from libinfiniop import (
 _TEST_CASES_ = [
     # y_shape, x_shape, w_shape, y_stride, x_stride
     ((1, 4), (1, 4), (4,), None, None),
-    ((1, 4), (1, 4), (4,), None, None),
-    ((16, 2048), (16, 2048), (2048,), None, None),
+    ((2, 4), (2, 4), (4,), None, None),
+    ((2, 2, 4), (2, 2, 4), (4,), None, None),
+    ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
     ((16, 2048), (16, 2048), (2048,), None, None),
-    ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
     ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
+    ((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
+    ((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
+    ((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),
 ]

 # w (weight) types
......
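One of the new cases above, shape (4, 4, 2048) with strides (2048, 8192, 1), deliberately gives the head dimension a larger stride than the batch dimension, i.e. the outer dimensions are laid out "transposed". This is covered because the kernels only ever form base + batch_idx * stride_batch + head_idx * stride_nhead, so no row-major ordering of the outer strides is assumed. Concretely:

    #include <cstdio>

    int main() {
        // From the ((4, 4, 2048), strides (2048, 8192, 1)) case: the head
        // stride (8192) exceeds the batch stride (2048); the offset formula
        // is agnostic to which outer stride is larger.
        const long stride_batch = 2048, stride_nhead = 8192;
        const long batch_idx = 3, head_idx = 2;
        std::printf("offset = %ld\n", batch_idx * stride_batch + head_idx * stride_nhead);
    }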