Commit 90fcbcc9 authored by Catheriany's avatar Catheriany
Browse files

issue/150: rope算子测例

parent 439ba32f
......@@ -7,6 +7,7 @@
*/
DECLARE_INFINIOP_TEST(gemm)
DECLARE_INFINIOP_TEST(random_sample)
DECLARE_INFINIOP_TEST(rope)
#define REGISTER_INFINIOP_TEST(name) \
{ \
......@@ -24,7 +25,7 @@ DECLARE_INFINIOP_TEST(random_sample)
{ \
REGISTER_INFINIOP_TEST(gemm) \
REGISTER_INFINIOP_TEST(random_sample) \
}
REGISTER_INFINIOP_TEST(rope)}
namespace infiniop_test {
......
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>
namespace infiniop_test::rope {
// Tensors that make up one RoPE test case. All six are filled in by
// Test::build() from the entries of the same name in the parsed tensor map.
struct Test::Attributes {
// Passed as the first tensor operand of infiniopRoPE and compared against `ans`.
std::shared_ptr<Tensor> y;
// Passed as the second tensor operand of infiniopRoPE.
std::shared_ptr<Tensor> x;
// Position-id tensor forwarded to the operator.
std::shared_ptr<Tensor> pos_ids;
// Sine lookup table forwarded to the operator.
std::shared_ptr<Tensor> sin_table;
// Cosine lookup table forwarded to the operator.
std::shared_ptr<Tensor> cos_table;
// Reference tensor that `y` is checked against with allClose().
std::shared_ptr<Tensor> ans;
};
std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();
if (tensors.find("y") == tensors.end()
|| tensors.find("x") == tensors.end()
|| tensors.find("pos_ids") == tensors.end()
|| tensors.find("sin_table") == tensors.end()
|| tensors.find("cos_table") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test");
}
test->_attributes->y = tensors["y"];
test->_attributes->x = tensors["x"];
test->_attributes->pos_ids = tensors["pos_ids"];
test->_attributes->sin_table = tensors["sin_table"];
test->_attributes->cos_table = tensors["cos_table"];
test->_attributes->ans = tensors["ans"];
return test;
}
/// Runs the RoPE operator once for correctness and then benchmarks it.
///
/// The input tensors are moved to the requested device, the operator
/// descriptor and workspace are created, the operator is executed, and the
/// output `y` is compared against the reference `ans` using the tolerances
/// stored in the test. On success the operator is timed with benchmark().
///
/// @param handle     Operator library handle.
/// @param device     Device type to run on.
/// @param device_id  Index of the device to run on.
/// @param warm_ups   Number of warm-up runs passed to benchmark().
/// @param iterations Number of timed runs passed to benchmark().
/// @return TEST_PASSED with the time reported by benchmark(), or TEST_FAILED
///         describing the step that failed.
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
    infiniopRoPEDescriptor_t op_desc = nullptr;
    void *workspace = nullptr;

    // Scope guard: releases the workspace and the descriptor on every exit
    // path (early failure returns included) so repeated test runs do not leak.
    struct Releaser {
        infiniopRoPEDescriptor_t &desc;
        void *&buffer;
        ~Releaser() {
            if (buffer != nullptr) {
                infinirtFree(buffer);
            }
            if (desc != nullptr) {
                infiniopDestroyRoPEDescriptor(desc);
            }
        }
    } releaser{op_desc, workspace};

    auto y = _attributes->y->to(device, device_id);
    auto x = _attributes->x->to(device, device_id);
    auto pos_ids = _attributes->pos_ids->to(device, device_id);
    auto sin_table = _attributes->sin_table->to(device, device_id);
    auto cos_table = _attributes->cos_table->to(device, device_id);

    CHECK_OR(infiniopCreateRoPEDescriptor(handle, &op_desc,
                                          y->desc(),
                                          x->desc(),
                                          pos_ids->desc(),
                                          sin_table->desc(),
                                          cos_table->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));

    size_t workspace_size = 0;
    CHECK_OR(infiniopGetRoPEWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));

    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));

    // Single run whose result is validated against the reference.
    CHECK_OR(infiniopRoPE(op_desc, workspace, workspace_size,
                          y->data(),
                          x->data(),
                          pos_ids->data(),
                          sin_table->data(),
                          cos_table->data(),
                          nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    try {
        allClose(y, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    // Timing pass; the status returned by the operator is intentionally ignored here.
    double elapsed_time = benchmark(
        [=]() {
            infiniopRoPE(
                op_desc, workspace, workspace_size,
                y->data(),
                x->data(),
                pos_ids->data(),
                sin_table->data(),
                cos_table->data(),
                nullptr);
        },
        warm_ups, iterations);

    return TEST_PASSED(elapsed_time);
}
/// Names of the non-tensor attributes this test reads; RoPE uses none.
std::vector<std::string> Test::attribute_names() {
    return std::vector<std::string>();
}
/// Names of the tensors a RoPE test case must provide, in the order
/// output, input, position ids, sine table, cosine table, reference answer.
std::vector<std::string> Test::tensor_names() {
    std::vector<std::string> names{"y", "x", "pos_ids", "sin_table", "cos_table", "ans"};
    return names;
}
/// Produces a multi-line, human-readable summary of the test: the operator
/// name, one line of info() per operand tensor, and the tolerances printed in
/// scientific notation with two digits of precision.
std::string Test::toString() const {
    std::ostringstream out;
    out << op_name() << '\n'
        << "- y: " << _attributes->y->info() << '\n'
        << "- x: " << _attributes->x->info() << '\n'
        << "- pos_ids: " << _attributes->pos_ids->info() << '\n'
        << "- sin_table: " << _attributes->sin_table->info() << '\n'
        << "- cos_table: " << _attributes->cos_table->info() << '\n';
    out << std::scientific << std::setprecision(2)
        << "- rtol=" << _rtol << ", atol=" << _atol << '\n';
    return out.str();
}
// Frees the Attributes object that Test::build() allocated with `new`.
Test::~Test() {
delete _attributes;
}
} // namespace infiniop_test::rope
......@@ -46,8 +46,8 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
const Tdata *sin_table,
const Tdata *cos_table,
cudaStream_t stream) {
auto dimx = unsigned int(info.seqlen),
dimy = unsigned int(info.nhead);
auto dimx = static_cast<unsigned int>(info.seqlen);
auto dimy = static_cast<unsigned int>(info.nhead);
int nthreads = std::max(int(info.table_dim), block_size);
ropeThreadPerItem<<<dim3(dimx, dimy), nthreads, 0, stream>>>(
......
from ast import List
import numpy as np
import gguf
from enum import Enum, auto
from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides
def rotary_embedding(t, sin, cos):
    """Apply rotary position embedding to interleaved (even, odd) pairs.

    Each pair ``(t[..., 2i], t[..., 2i + 1])`` is rotated by the angle whose
    sine and cosine are given in ``sin`` / ``cos``. A new axis is inserted at
    position 1 of both tables so that they broadcast against ``t``.

    Args:
        t: Array whose third axis (``t.shape[2]``) has even length.
        sin: Sine table; gains a singleton axis at position 1 before use.
        cos: Cosine table; gains a singleton axis at position 1 before use.

    Returns:
        Array with the same shape and dtype as ``t`` holding the rotated values.

    Raises:
        AssertionError: If ``t.shape[2]`` is odd.
    """
    head_dim = t.shape[2]
    assert head_dim % 2 == 0, "Embedding dimension must be even."

    evens, odds = t[..., 0::2], t[..., 1::2]
    cos_b = np.expand_dims(cos, axis=1)
    sin_b = np.expand_dims(sin, axis=1)

    # Write both halves of each rotated pair back into their interleaved slots.
    rotated = np.empty_like(t)
    rotated[..., 0::2] = evens * cos_b - odds * sin_b
    rotated[..., 1::2] = evens * sin_b + odds * cos_b
    return rotated
def sin_cos_table(pos, dim, theta, dtype):
    """Build the sine and cosine lookup tables for rotary embedding.

    The i-th frequency is ``1 / theta ** (2 * i / dim)``; the angle for a
    position is the product of the position and each frequency.

    Args:
        pos: 1-D sequence of positions.
        dim: Embedding dimension; must be even.
        theta: Base of the frequency progression.
        dtype: NumPy dtype of the returned tables.

    Returns:
        ``(sin_vals, cos_vals)``, each of shape ``(len(pos), dim // 2)`` and
        of the requested ``dtype``.

    Raises:
        AssertionError: If ``dim`` is odd.
    """
    assert dim % 2 == 0, "Embedding dimension must be even."
    half = dim // 2
    exponents = np.arange(0, dim, 2)[:half].astype(np.float32) / dim
    freqs = 1.0 / (theta**exponents)
    angles = np.outer(pos, freqs)
    return np.sin(angles).astype(dtype), np.cos(angles).astype(dtype)
class RoPETestCase(InfiniopTestCase):
    """One serialized RoPE test case.

    Holds the operand arrays (and optional strides for ``y`` / ``x``) and
    writes them, together with a float64 reference answer, through an
    ``InfiniopTestWriter``.
    """

    # NOTE: the annotations use the builtin ``list[int]``. The ``List`` name
    # imported at the top of this file comes from ``ast`` and is not
    # subscriptable, so ``List[int]`` raised TypeError when this class body
    # was evaluated.
    def __init__(
        self,
        y: np.ndarray,
        x: np.ndarray,
        stride_y: list[int] | None,
        stride_x: list[int] | None,
        pos_ids: np.ndarray,
        sin_table: np.ndarray,
        cos_table: np.ndarray,
    ):
        """Store the operands of the test case.

        Args:
            y: Output array written under the key "y".
            x: Input array written under the key "x".
            stride_y: Strides recorded as "y.strides", or None to omit them.
            stride_x: Strides recorded as "x.strides", or None to omit them.
            pos_ids: Position ids written under the key "pos_ids".
            sin_table: Sine table written under the key "sin_table".
            cos_table: Cosine table written under the key "cos_table".
        """
        super().__init__("rope")
        self.y = y
        self.x = x
        self.stride_y = stride_y
        self.stride_x = stride_x
        self.pos_ids = pos_ids
        self.sin_table = sin_table
        self.cos_table = cos_table

    def write_test(self, test_writer: "InfiniopTestWriter"):
        """Write every tensor of this test case, plus the reference answer."""
        super().write_test(test_writer)
        test_writer.add_tensor(
            test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype)
        )
        test_writer.add_tensor(
            test_writer.gguf_key("x"), self.x, raw_dtype=np_dtype_to_ggml(self.x.dtype)
        )
        # Strides are optional metadata and are only recorded when provided.
        if self.stride_y is not None:
            test_writer.add_array(test_writer.gguf_key("y.strides"), self.stride_y)
        if self.stride_x is not None:
            test_writer.add_array(test_writer.gguf_key("x.strides"), self.stride_x)
        test_writer.add_tensor(
            test_writer.gguf_key("pos_ids"),
            self.pos_ids,
            raw_dtype=np_dtype_to_ggml(self.pos_ids.dtype),
        )
        test_writer.add_tensor(
            test_writer.gguf_key("sin_table"),
            self.sin_table,
            raw_dtype=np_dtype_to_ggml(self.sin_table.dtype),
        )
        test_writer.add_tensor(
            test_writer.gguf_key("cos_table"),
            self.cos_table,
            raw_dtype=np_dtype_to_ggml(self.cos_table.dtype),
        )
        # The reference answer is computed in float64 from the stored operands.
        ans = rotary_embedding(
            self.x.astype(np.float64),
            self.sin_table.astype(np.float64),
            self.cos_table.astype(np.float64),
        )
        test_writer.add_tensor(
            test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64
        )
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# Each entry is (shape, x_strides, y_strides); None means no strides are recorded.
((1, 32, 128), None, None),
((10, 32, 64), None, None),
# # Ascend currently fails this case: a last dimension <= 32 is problematic,
# # possibly because of the internal implementation of its core GatherMask
# # interface; 48, 64 and 128 are currently supported.
((4, 1, 32), gguf_strides(64, 64, 1), None),
((11, 33, 128), None, gguf_strides(8000, 200, 1)),
((3, 32, 128), gguf_strides(8000, 200, 1), gguf_strides(7000, 128, 1)),
]
# Element types for which every test case above is generated.
_TENSOR_DTYPES_ = [np.float16, np.float32]
class Inplace(Enum):
OUT_OF_PLACE = auto()
INPLACE_X = auto()
_INPLACE_ = [
Inplace.OUT_OF_PLACE,
Inplace.INPLACE_X,
]
if __name__ == "__main__":
    # Generate one test case per (dtype, shape/strides, in-place mode)
    # combination and write them all to "rope.gguf".
    test_writer = InfiniopTestWriter("rope.gguf")
    test_cases = []
    for dtype in _TENSOR_DTYPES_:
        for shape, stride_x, stride_y in _TEST_CASES_:
            for mode in _INPLACE_:
                x = np.random.rand(*shape).astype(dtype)
                # In-place cases share one array between input and output.
                y = x if mode == Inplace.INPLACE_X else np.random.rand(*shape).astype(dtype)

                seq_len, head_dim = x.shape[0], x.shape[2]
                pos_ids = np.arange(0, seq_len, dtype=np.int32)
                sin_table, cos_table = sin_cos_table(pos_ids, head_dim, theta=1e5, dtype=dtype)

                test_cases.append(
                    RoPETestCase(
                        y=y,
                        x=x,
                        stride_y=stride_y,
                        stride_x=stride_x,
                        pos_ids=pos_ids,
                        sin_table=sin_table,
                        cos_table=cos_table,
                    )
                )
    test_writer.add_tests(test_cases)
    test_writer.save()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment