Commit 40e07381 authored by xiabo's avatar xiabo
Browse files

Adapt to rocm FT的修改补充

parent ab8c95cb
...@@ -209,7 +209,8 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh ...@@ -209,7 +209,8 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id)); ft::check_cuda_error(cudaGetDeviceProperties(cuda_device_prop_ptr.get(), device_id));
if (std::is_same<T, half>::value) { if (std::is_same<T, half>::value) {
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F); // cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F);
cublas_wrapper->setGemmConfig(CUDA_R_16F, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F);
} }
else if (std::is_same<T, float>::value) { else if (std::is_same<T, float>::value) {
cublas_wrapper->setFP32GemmConfig(); cublas_wrapper->setFP32GemmConfig();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment