Commit e120c9b6 authored by Casper Hansen

Use CUDA stream

parent 1aa8aebd
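This change makes the 4-bit GEMM and GEMV wrappers launch their kernels on the CUDA stream PyTorch is currently using (via at::cuda::getCurrentCUDAStream()) instead of the default stream, so the kernels order correctly with the rest of the work the caller has enqueued (for example under torch.cuda.stream(...) contexts, where the default stream is not the active one). The same pattern is applied in both the GEMM and the GEMV launch sites in the hunks below. A minimal sketch of that pattern follows; the kernel and wrapper names here are illustrative only, not the actual AWQ kernels:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>

// Toy kernel used only to illustrate the launch pattern.
__global__ void scale_half_kernel(half* data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] = __float2half(__half2float(data[i]) * factor);
}

void scale_half_inplace(torch::Tensor t, float factor) {
    auto data = reinterpret_cast<half*>(t.data_ptr<at::Half>());
    int n = static_cast<int>(t.numel());
    // Fetch the stream PyTorch is currently using and pass it as the fourth
    // launch parameter (dynamic shared memory stays 0), exactly as the diff
    // does for gemm_forward_4bit_* and gemv_kernel_g*.
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    dim3 threads(256);
    dim3 blocks((n + threads.x - 1) / threads.x);
    scale_half_kernel<<<blocks, threads, 0, stream>>>(data, factor, n);
}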
@@ -10,6 +10,7 @@
 */
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemm_cuda.h"
 #include "dequantize.cuh"
 #include <cuda_fp16.h>
@@ -439,6 +440,7 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
     int group_size = num_in_channels / _scaling_factors.size(0);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (num_out_channels % 64 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 64");
@@ -456,7 +458,7 @@
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -467,7 +469,7 @@
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
@@ -13,6 +13,7 @@
 #include <cuda_fp16.h>
 #include <stdio.h>
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemv_cuda.h"
 #define VECTORIZE_FACTOR 8
 #define Q_VECTORIZE_FACTOR 8
@@ -224,9 +225,10 @@ torch::Tensor gemv_forward_cuda(
     int blockDim_z = num_out_feats;
     dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
     dim3 num_threads(32, 4);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (group_size == 64)
     {
-        gemv_kernel_g64<<<num_blocks, num_threads>>>(
+        gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
             // pointers
             in_feats, kernel, zeros, scaling_factors, out_feats,
             // constants
@@ -235,7 +237,7 @@
     }
     else if (group_size == 128)
     {
-        gemv_kernel_g128<<<num_blocks, num_threads>>>(
+        gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
             // pointers
             in_feats, kernel, zeros, scaling_factors, out_feats,
             // constants
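As a usage note (not part of the commit itself): once the launches take the current stream, whatever stream the caller makes current through the usual c10 guard mechanism carries through to these kernels. A minimal sketch, assuming the extension is built against libtorch with CUDA:

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_side_stream() {
    // Take a stream from PyTorch's pool and make it current for this scope.
    c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
    {
        c10::cuda::CUDAStreamGuard guard(side);
        // Any call into the extension made here (gemm_forward_cuda,
        // gemv_forward_cuda) now launches its kernels on `side`, because the
        // wrappers read at::cuda::getCurrentCUDAStream().
    }
    // Wait for everything queued on the side stream before reading results.
    side.synchronize();
}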