Commit 3758afb5 authored by Azure's avatar Azure
Browse files

fix some dequant function dosen't support multi gpu bug

parent 3ed8a043
...@@ -292,6 +292,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de ...@@ -292,6 +292,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) { torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size; int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options); auto data_gpu = torch::empty({data.numel()}, options);
...@@ -330,6 +331,7 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de ...@@ -330,6 +331,7 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) { torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size; int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options); auto data_gpu = torch::empty({data.numel()}, options);
...@@ -348,6 +350,7 @@ torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device de ...@@ -348,6 +350,7 @@ torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device de
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) { torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size; int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options); auto data_gpu = torch::empty({data.numel()}, options);
...@@ -366,6 +369,7 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de ...@@ -366,6 +369,7 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) { torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size; int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous); auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options); auto data_gpu = torch::empty({data.numel()}, options);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment