Commit 6b717b30 authored by PanZezhong's avatar PanZezhong
Browse files

issue/4/fix 修改info构建

parent a8955429
......@@ -17,11 +17,11 @@ infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc) {
CausalSoftmaxInfo info;
CHECK_STATUS(createCausalSoftmaxInfo(&info, y_desc));
auto info = CausalSoftmaxInfo::create(y_desc);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::cuda::Handle *>(handle)->internal()},
info, 0, handle->device, handle->device_id);
info.take(), 0, handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
......@@ -42,7 +42,7 @@ infiniStatus_t launchKernel(void *data, infiniDtype_t dtype, size_t batch_size,
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
void *data,
void *stream_) {
void *stream_) const {
cudaStream_t stream = (cudaStream_t)stream_;
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment