Commit 9f0ae734 authored by Ziminli's avatar Ziminli
Browse files

issue/428: update the rope implementation on Ascend, Cambricon, and Kunlun to...

issue/428: update the rope implementation on Ascend, Cambricon, and Kunlun to use the refactored interface and return unimplemented error for NEOX-style algorithm
parent f6e8476b
...@@ -13,11 +13,16 @@ infiniStatus_t Descriptor::create( ...@@ -13,11 +13,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc, infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc, infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) { infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle_ascned = reinterpret_cast<device::ascend::Handle *>(handle); auto handle_ascned = reinterpret_cast<device::ascend::Handle *>(handle);
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(result); CHECK_RESULT(result);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
size_t workspace_size = 0; size_t workspace_size = 0;
*desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr, handle_ascned->device, handle_ascned->device_id); *desc_ptr = new Descriptor(std::move(result.take()), workspace_size, nullptr, handle_ascned->device, handle_ascned->device_id);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
......
...@@ -13,13 +13,18 @@ infiniStatus_t Descriptor::create( ...@@ -13,13 +13,18 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc, infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc, infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) { infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto handle = reinterpret_cast<device::bang::Handle *>(handle_); auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(info); CHECK_RESULT(info);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor // Create descriptor
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
info.take(), info.take(),
......
...@@ -118,11 +118,16 @@ infiniStatus_t Descriptor::create( ...@@ -118,11 +118,16 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t x_desc, infiniopTensorDescriptor_t x_desc,
infiniopTensorDescriptor_t pos_desc, infiniopTensorDescriptor_t pos_desc,
infiniopTensorDescriptor_t sin_desc, infiniopTensorDescriptor_t sin_desc,
infiniopTensorDescriptor_t cos_desc) { infiniopTensorDescriptor_t cos_desc,
infiniopRoPEAlgo_t algo) {
auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc); auto result = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
CHECK_RESULT(result); CHECK_RESULT(result);
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
return INFINI_STATUS_NOT_IMPLEMENTED;
}
// Create descriptor // Create descriptor
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
result.take(), result.take(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment