Commit 9f0ae734 authored by Ziminli's avatar Ziminli
Browse files

issue/428: update the rope implementation on Ascend, Cambricon, and Kunlun to...

issue/428: update the rope implementation on Ascend, Cambricon, and Kunlun to use the refactored interface and return unimplemented error for NEOX-style algorithm
parent f6e8476b
......@@ -13,11 +13,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle_ascned = reinterpret_cast<device::ascend::Handle *>(handle);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(result);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
size_t workspace_size = 0;
*desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr, handle_ascned->device, handle_ascned->device_id);
return INFINI_STATUS_SUCCESS;
......
......@@ -13,13 +13,18 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor
*desc_ptr = new Descriptor(
info.take(),
......
......@@ -118,11 +118,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) {
infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(result);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor
*desc_ptr = new Descriptor(
result.take(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment