Commit e120c9b6 authored by Casper Hansen

Use CUDA stream

parent 1aa8aebd
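
This change routes every kernel launch through PyTorch's current CUDA stream instead of the legacy default stream, so the AWQ GEMM/GEMV kernels stay ordered with the surrounding ATen ops (for example, when the caller runs under torch.cuda.stream(...)). Below is a minimal sketch of the pattern applied in this commit; scale_kernel and scale_forward are hypothetical stand-ins for the real entry points:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

// Hypothetical elementwise kernel standing in for the AWQ GEMM/GEMV kernels.
__global__ void scale_kernel(half* x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] = __float2half(__half2float(x[i]) * s);
}

void scale_forward(torch::Tensor t, float s) {
    // Fetch the stream PyTorch is currently using on this device; launching
    // on it keeps the kernel ordered with surrounding ATen work.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    int n = t.numel();
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    // The fourth launch-config argument selects the stream
    // (the third is dynamic shared memory, 0 here).
    scale_kernel<<<blocks, threads, 0, stream>>>(
        reinterpret_cast<half*>(t.data_ptr<at::Half>()), s, n);
}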
@@ -10,6 +10,7 @@
  */
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemm_cuda.h"
 #include "dequantize.cuh"
 #include <cuda_fp16.h>
@@ -439,6 +440,7 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
     int group_size = num_in_channels / _scaling_factors.size(0);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (num_out_channels % 64 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 64");
@@ -456,7 +458,7 @@ torch::Tensor gemm_forward_cuda(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
         group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -467,7 +469,7 @@ torch::Tensor gemm_forward_cuda(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
         group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
...
@@ -13,6 +13,7 @@
 #include <cuda_fp16.h>
 #include <stdio.h>
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemv_cuda.h"
 #define VECTORIZE_FACTOR 8
 #define Q_VECTORIZE_FACTOR 8
@@ -224,9 +225,10 @@ torch::Tensor gemv_forward_cuda(
     int blockDim_z = num_out_feats;
     dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
     dim3 num_threads(32, 4);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (group_size == 64)
     {
-        gemv_kernel_g64<<<num_blocks, num_threads>>>(
+        gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
             // pointers
             in_feats, kernel, zeros, scaling_factors, out_feats,
             // constants
@@ -235,7 +237,7 @@ torch::Tensor gemv_forward_cuda(
     }
     else if (group_size == 128)
     {
-        gemv_kernel_g128<<<num_blocks, num_threads>>>(
+        gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
             // pointers
             in_feats, kernel, zeros, scaling_factors, out_feats,
             // constants
...