"git@developer.sourcefind.cn:yaoyuping/nndetection.git" did not exist on "3f162f5e10fc029ed67bab9a03b9a732c72e2cc0"
Commit 0450fb1e authored by PanZezhong
Browse files

issue/161 rope和causal softmax支持非原地

parent 61734617
...@@ -10,14 +10,13 @@ ...@@ -10,14 +10,13 @@
#include "infiniop/ops/expand.h" #include "infiniop/ops/expand.h"
#include "infiniop/ops/gemm.h" #include "infiniop/ops/gemm.h"
#include "infiniop/ops/global_avg_pool.h" #include "infiniop/ops/global_avg_pool.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/max_pool.h" #include "infiniop/ops/max_pool.h"
#include "infiniop/ops/mlp.h" #include "infiniop/ops/mlp.h"
#include "infiniop/ops/random_sample.h" #include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h" #include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h" #include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h" #include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rotary_embedding.h" #include "infiniop/ops/rope.h"
#include "infiniop/ops/swiglu.h" #include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h" #include "infiniop/tensor_descriptor.h"
......
...@@ -5,17 +5,21 @@ ...@@ -5,17 +5,21 @@
typedef struct InfiniopDescriptor *infiniopCausalSoftmaxDescriptor_t; typedef struct InfiniopDescriptor *infiniopCausalSoftmaxDescriptor_t;
__C __export infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(infiniopHandle_t handle, __C __export infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
infiniopCausalSoftmaxDescriptor_t *desc_ptr, infiniopHandle_t handle,
infiniopTensorDescriptor_t y_desc); infiniopCausalSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc);
__C __export infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size); __C __export infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc, __C __export infiniStatus_t infiniopCausalSoftmax(
void *workspace, infiniopCausalSoftmaxDescriptor_t desc,
size_t workspace_size, void *workspace,
void *data, size_t workspace_size,
void *stream); void *y,
const void *x,
void *stream);
__C __export infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc); __C __export infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc);
......
#ifndef __INFINIOP_ROTARY_EMBEDDING_API_H__ #ifndef __INFINIOP_ROPE_API_H__
#define __INFINIOP_ROTARY_EMBEDDING_API_H__ #define __INFINIOP_ROPE_API_H__
#include "../operator_descriptor.h" #include "../operator_descriptor.h"
...@@ -8,7 +8,8 @@ typedef struct InfiniopDescriptor *infiniopRoPEDescriptor_t; ...@@ -8,7 +8,8 @@ typedef struct InfiniopDescriptor *infiniopRoPEDescriptor_t;
__C __export infiniStatus_t infiniopCreateRoPEDescriptor( __C __export infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopRoPEDescriptor_t *desc_ptr, infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t t, infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids, infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table, infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table); infiniopTensorDescriptor_t cos_table);
...@@ -19,7 +20,8 @@ __C __export infiniStatus_t infiniopRoPE( ...@@ -19,7 +20,8 @@ __C __export infiniStatus_t infiniopRoPE(
infiniopRoPEDescriptor_t desc, infiniopRoPEDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *t, void *y,
const void *x,
void const *pos_ids, void const *pos_ids,
void const *sin_table, void const *sin_table,
void const *cos_table, void const *cos_table,
......
...@@ -32,11 +32,13 @@ ...@@ -32,11 +32,13 @@
static infiniStatus_t create( \ static infiniStatus_t create( \
infiniopHandle_t handle, \ infiniopHandle_t handle, \
Descriptor **desc_ptr, \ Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \ infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
\ \
infiniStatus_t calculate( \ infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \ void *workspace, size_t workspace_size, \
void *data, \ void *y, \
const void *x, \
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
......
...@@ -9,44 +9,46 @@ Descriptor::~Descriptor() {} ...@@ -9,44 +9,46 @@ Descriptor::~Descriptor() {}
infiniStatus_t Descriptor::create( infiniStatus_t Descriptor::create(
infiniopHandle_t handle, infiniopHandle_t handle,
Descriptor **desc_ptr, Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) { infiniopTensorDescriptor_t y_desc,
auto result = CausalSoftmaxInfo::create(y_desc); infiniopTensorDescriptor_t x_desc) {
auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
CHECK_RESULT(result); CHECK_RESULT(result);
*desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
template <typename T> template <typename T>
infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) { infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *y, const T *x) {
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t index = 0; index < ptrdiff_t(info->batch_size * info->seq_len); index++) { for (ptrdiff_t index = 0; index < ptrdiff_t(info->batch_size * info->seq_len); index++) {
size_t ind = index; size_t batch = index / info->seq_len;
size_t offset = 0; size_t i = (index % info->seq_len);
size_t i = (ind % info->seq_len); ptrdiff_t y_offset = batch * info->y_stride_b + i * info->y_stride_i;
offset += (ind % info->seq_len) * info->stride_i; ptrdiff_t x_offset = batch * info->x_stride_b + i * info->x_stride_i;
ind /= info->seq_len; T *y_ = y + y_offset;
offset += (ind % info->batch_size) * info->stride_b; const T *x_ = x + x_offset;
for (size_t j = info->total_seq_len - info->seq_len + i + 1; j < info->total_seq_len; j++) { for (size_t j = info->total_seq_len - info->seq_len + i + 1; j < info->total_seq_len; j++) {
if constexpr (std::is_same<T, fp16_t>::value) { if constexpr (std::is_same<T, fp16_t>::value) {
data[offset + j * info->stride_j] = utils::cast<fp16_t>(0.0f); y_[j * info->y_stride_j] = utils::cast<fp16_t>(0.0f);
} else { } else {
data[offset + j * info->stride_j] = 0.0f; y_[j * info->y_stride_j] = 0.0f;
} }
} }
float val = op::common_cpu::reduce_op::max(&data[offset], info->total_seq_len - info->seq_len + i + 1, info->stride_j); float val = op::common_cpu::reduce_op::max(x_, info->total_seq_len - info->seq_len + i + 1, info->x_stride_j);
for (size_t j = 0; j <= info->total_seq_len - info->seq_len + i; j++) { for (size_t j = 0; j <= info->total_seq_len - info->seq_len + i; j++) {
if constexpr (std::is_same<T, fp16_t>::value) { if constexpr (std::is_same<T, fp16_t>::value) {
data[offset + j * info->stride_j] = utils::cast<fp16_t>(std::exp(utils::cast<float>(data[offset + j * info->stride_j]) - val)); y_[j * info->y_stride_j] = utils::cast<fp16_t>(std::exp(utils::cast<float>(x_[j * info->x_stride_j]) - val));
} else { } else {
data[offset + j * info->stride_j] = std::exp(data[offset + j * info->stride_j] - val); y_[j * info->y_stride_j] = std::exp(x_[j * info->x_stride_j] - val);
} }
} }
float sum = op::common_cpu::reduce_op::sum(&data[offset], info->total_seq_len - info->seq_len + i + 1, info->stride_j); float sum = op::common_cpu::reduce_op::sum(y_, info->total_seq_len - info->seq_len + i + 1, info->y_stride_j);
for (size_t j = 0; j <= info->total_seq_len - info->seq_len + i; j++) { for (size_t j = 0; j <= info->total_seq_len - info->seq_len + i; j++) {
if constexpr (std::is_same<T, fp16_t>::value) { if constexpr (std::is_same<T, fp16_t>::value) {
data[offset + j * info->stride_j] = utils::cast<fp16_t>(utils::cast<float>(data[offset + j * info->stride_j]) / sum); y_[j * info->y_stride_j] = utils::cast<fp16_t>(utils::cast<float>(y_[j * info->y_stride_j]) / sum);
} else { } else {
data[offset + j * info->stride_j] = data[offset + j * info->stride_j] / sum; y_[j * info->y_stride_j] = y_[j * info->y_stride_j] / sum;
} }
} }
} }
...@@ -56,13 +58,14 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) { ...@@ -56,13 +58,14 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
infiniStatus_t Descriptor::calculate( infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size, void *workspace, size_t workspace_size,
void *data, void *y,
const void *x,
void *stream) const { void *stream) const {
if (_info.dtype == INFINI_DTYPE_F16) { if (_info.dtype == INFINI_DTYPE_F16) {
CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)data)); CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)y, (const fp16_t *)x));
} else if (_info.dtype == INFINI_DTYPE_F32) { } else if (_info.dtype == INFINI_DTYPE_F32) {
CHECK_STATUS(causal_softmax<float>(&_info, (float *)data)); CHECK_STATUS(causal_softmax<float>(&_info, (float *)y, (const float *)x));
} else { } else {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
......
...@@ -13,45 +13,69 @@ class CausalSoftmaxInfo { ...@@ -13,45 +13,69 @@ class CausalSoftmaxInfo {
public: public:
infiniDtype_t dtype; infiniDtype_t dtype;
size_t batch_size; size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len; size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len; size_t total_seq_len;
ptrdiff_t stride_j;
static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc) { ptrdiff_t y_stride_b;
ptrdiff_t y_stride_i;
ptrdiff_t y_stride_j;
ptrdiff_t x_stride_b;
ptrdiff_t x_stride_i;
ptrdiff_t x_stride_j;
static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) {
auto dtype = y_desc->dtype(); auto dtype = y_desc->dtype();
if (y_desc->dtype() != INFINI_DTYPE_F16 && y_desc->dtype() != INFINI_DTYPE_F32) { if (dtype != x_desc->dtype()) {
return INFINI_STATUS_BAD_TENSOR_DTYPE; return INFINI_STATUS_BAD_TENSOR_DTYPE;
} }
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
auto ndim = y_desc->ndim();
if (ndim != x_desc->ndim()) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (y_desc->ndim() != 2 && y_desc->ndim() != 3) { if (ndim != 2 && ndim != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
if (y_desc->shape()[y_desc->ndim() - 1] < y_desc->shape()[y_desc->ndim() - 2]) { auto shape = y_desc->shape();
if (!SAME_VEC(y_desc->shape(), x_desc->shape())) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (shape[ndim - 1] < shape[ndim - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
size_t batch_size = 1; size_t batch_size = 1;
ptrdiff_t stride_b = 0; size_t seq_len = shape[ndim - 2];
size_t seq_len = y_desc->shape()[y_desc->ndim() - 2]; size_t total_seq_len = shape[ndim - 1];
ptrdiff_t stride_i = y_desc->strides()[y_desc->ndim() - 2]; ptrdiff_t y_stride_b = 0,
size_t total_seq_len = y_desc->shape()[y_desc->ndim() - 1]; x_stride_b = 0;
ptrdiff_t stride_j = y_desc->strides()[y_desc->ndim() - 1]; ptrdiff_t y_stride_i = y_desc->stride(ndim - 2),
if (y_desc->ndim() == 3) { y_stride_j = y_desc->stride(ndim - 1);
stride_b = y_desc->strides()[0]; ptrdiff_t x_stride_i = x_desc->stride(ndim - 2),
batch_size = y_desc->shape()[0]; x_stride_j = x_desc->stride(ndim - 1);
if (ndim == 3) {
y_stride_b = y_desc->stride(0);
x_stride_b = x_desc->stride(0);
batch_size = shape[0];
} }
return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{ return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
dtype, dtype,
batch_size, batch_size,
stride_b,
seq_len, seq_len,
stride_i,
total_seq_len, total_seq_len,
stride_j}); y_stride_b,
y_stride_i,
y_stride_j,
x_stride_b,
x_stride_i,
x_stride_j});
} }
}; };
......
...@@ -9,14 +9,16 @@ ...@@ -9,14 +9,16 @@
__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopCausalSoftmaxDescriptor_t *desc_ptr, infiniopCausalSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc) { infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE) \ #define CREATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return op::causal_softmax::NAMESPACE::Descriptor::create( \ return op::causal_softmax::NAMESPACE::Descriptor::create( \
handle, \ handle, \
reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc); y_desc, \
x_desc);
switch (handle->device) { switch (handle->device) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
...@@ -96,12 +98,17 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe ...@@ -96,12 +98,17 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *data, void *stream) { __C infiniStatus_t infiniopCausalSoftmax(
infiniopCausalSoftmaxDescriptor_t desc,
void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->calculate( \ return reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, data, stream); workspace, workspace_size, y, x, stream);
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU_API #ifdef ENABLE_CPU_API
......
#include "../../operator.h" #include "../../operator.h"
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/rotary_embedding.h" #include "infiniop/ops/rope.h"
__C infiniStatus_t infiniopCreateRoPEDescriptor( __C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopRoPEDescriptor_t *desc_ptr, infiniopHandle_t handle,
infiniopTensorDescriptor_t t, infiniopTensorDescriptor_t pos_ids, infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table, infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table) { infiniopTensorDescriptor_t cos_table) {
switch (handle->device) { switch (handle->device) {
...@@ -91,11 +94,16 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, ...@@ -91,11 +94,16 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
} }
__C infiniStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc, __C infiniStatus_t infiniopRoPE(
void *workspace, size_t workspace_size, infiniopRoPEDescriptor_t desc,
void *t, const void *pos_ids, void *workspace,
const void *sin_table, const void *cos_table, size_t workspace_size,
void *stream) { void *y,
const void *x,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
void *stream) {
switch (desc->device_type) { switch (desc->device_type) {
#ifdef ENABLE_CPU #ifdef ENABLE_CPU
case DevCpu: case DevCpu:
......
...@@ -16,18 +16,20 @@ from libinfiniop import ( ...@@ -16,18 +16,20 @@ from libinfiniop import (
get_tolerance, get_tolerance,
profile_operation, profile_operation,
) )
from enum import Enum, auto
# ============================================================================== # ==============================================================================
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES = [ _TEST_CASES_ = [
# x_shape, x_stride # shape, x_stride, y_stride
((32, 512), None), ((3, 3), None, None),
((32, 512), (1024, 1)), ((32, 512), None, None),
((32, 5, 5), None), ((32, 512), (1024, 1), (1024, 1)),
((32, 20, 512), None), ((32, 5, 5), None, None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续 ((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None), # Ascend 暂不支持非连续
] ]
# Data types used for testing # Data types used for testing
...@@ -38,6 +40,23 @@ _TOLERANCE_MAP = { ...@@ -38,6 +40,23 @@ _TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2}, torch.float16: {"atol": 0, "rtol": 1e-2},
} }
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_X = auto()
_INPLACE = [
Inplace.OUT_OF_PLACE,
Inplace.INPLACE_X,
]
_TEST_CASES = [
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
DEBUG = False DEBUG = False
PROFILE = False PROFILE = False
NUM_PRERUN = 10 NUM_PRERUN = 10
...@@ -59,12 +78,21 @@ def causal_softmax(x): ...@@ -59,12 +78,21 @@ def causal_softmax(x):
return torch.nn.functional.softmax(masked, dim=-1).to(type) return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16): def test(
lib,
handle,
torch_device,
shape,
x_stride=None,
y_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16,
):
print( print(
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}" f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}"
) )
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(shape, dtype=dtype).to(torch_device)
ans = causal_softmax(x) ans = causal_softmax(x)
...@@ -72,10 +100,18 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16) ...@@ -72,10 +100,18 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
x_tensor = to_tensor(x, lib) x_tensor = to_tensor(x, lib)
if inplace == Inplace.INPLACE_X:
y = x
y_tensor = x_tensor
else:
y = torch.zeros(shape, dtype=dtype).to(torch_device)
y = rearrange_if_needed(y, y_stride)
y_tensor = to_tensor(y, lib)
descriptor = infiniopCausalSoftmaxDescriptor_t() descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error( check_error(
lib.infiniopCreateCausalSoftmaxDescriptor( lib.infiniopCreateCausalSoftmaxDescriptor(
handle, ctypes.byref(descriptor), x_tensor.descriptor handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor
) )
) )
...@@ -96,6 +132,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16) ...@@ -96,6 +132,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
descriptor, descriptor,
workspace.data_ptr() if workspace is not None else None, workspace.data_ptr() if workspace is not None else None,
workspace_size.value, workspace_size.value,
y_tensor.data,
x_tensor.data, x_tensor.data,
None, None,
) )
...@@ -105,8 +142,8 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16) ...@@ -105,8 +142,8 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG: if DEBUG:
debug(x, ans, atol=atol, rtol=rtol) debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(x, ans, atol=atol, rtol=rtol) assert torch.allclose(y, ans, atol=atol, rtol=rtol)
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
......
...@@ -232,5 +232,5 @@ if __name__ == "__main__": ...@@ -232,5 +232,5 @@ if __name__ == "__main__":
# Execute tests # Execute tests
for device in get_test_devices(args): for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES) test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment