OpenDAS / AutoAWQ / Commits / 077f39a0

Commit 077f39a0, authored Sep 14, 2023 by Casper
Merge branch 'main' into pr/27
Parents: faedf517, d76125bf

Changes: 34. Showing 14 changed files with 4396 additions and 13 deletions (+4396, -13).
awq_cuda/attention/cuda_bf16_wrapper.h                               +23    -0
awq_cuda/attention/decoder_masked_multihead_attention.cu            +154    -0
awq_cuda/attention/decoder_masked_multihead_attention.h             +184    -0
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp  +1608   -0
awq_cuda/attention/decoder_masked_multihead_attention_utils.h       +1786   -0
awq_cuda/attention/ft_attention.cpp                                  +182   -0
awq_cuda/attention/ft_attention.h                                     +15   -0
awq_cuda/pybind.cpp                                                    +8   -1
awq_cuda/quantization/gemm_cuda_gen.cu                                 +4   -2
awq_cuda/quantization/gemv_cuda.cu                                   +249   -0
awq_cuda/quantization/gemv_cuda.h                                      +9   -0
examples/basic_quant.py                                                +1   -1
examples/benchmark.py                                                +132   -0
setup.py                                                              +41   -9
awq_cuda/attention/cuda_bf16_wrapper.h (new file, 0 → 100644)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
awq_cuda/attention/decoder_masked_multihead_attention.cu (new file, 0 → 100644)
// Adapted from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
    if (tlength < 32) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
    }
    else if (tlength < 2048) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
    }
    else {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#undef MMHA_LAUNCH_KERNEL
template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    switch (params.hidden_size_per_head) {
        case 32:
            mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 48:
            mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 64:
            mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 80:
            mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 96:
            mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 112:
            mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 128:
            mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 160:
            mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 192:
            mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 224:
            mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 256:
            mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        default:
            assert(false);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
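The launcher above chooses its block configuration from the sequence length: THREADS_PER_VALUE is derived from the padded head size and element width, while the (threads-per-key, threads-per-block) pair grows coarser as tlength grows. A small sketch (not part of the commit) reproduces the heuristic for a 2-byte element type:

# Sketch (not part of the commit): reproduce mmha_launch_kernel's heuristic
# for a 2-byte element type (fp16/bf16).
def mmha_launch_config(dh_max: int, tlength: int, dtype_bytes: int = 2):
    threads_per_value = dh_max * dtype_bytes // 16   # e.g. 128 * 2 / 16 = 16
    if tlength < 32:
        threads_per_key, threads_per_block = 4, 64
    elif tlength < 2048:
        threads_per_key, threads_per_block = 2, 128
    else:
        threads_per_key, threads_per_block = 1, 256
    return threads_per_key, threads_per_value, threads_per_block

print(mmha_launch_config(dh_max=128, tlength=500))   # (2, 16, 128)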
awq_cuda/attention/decoder_masked_multihead_attention.h (new file, 0 → 100644)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;
    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;
    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;
    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;
    // The indirections to use for cache when beam sampling.
    const int* cache_indir = nullptr;

    // Stride to handle the case when KQV is a single buffer
    int stride = 0;

    // The batch size.
    int batch_size = 0;
    // The beam width
    int beam_width = 0;
    // The sequence length.
    int memory_max_len = 0;
    // The number of heads (H).
    int num_heads = 0;
    // The number of heads for KV cache.
    int num_kv_heads = 0;
    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;
    // The per-head latent space reserved for rotary embeddings.
    int rotary_embedding_dim = 0;
    bool neox_rotary_style = false;
    float rotary_base = 0.0f;
    // The maximum length of input sentences.
    int max_input_length = 0;
    // The current timestep. TODO(bhsueh) Check whether we only need this param for cross attention.
    int timestep = 0;
    // The current timestep of each sentence (supports a different timestep per sentence)

    // The 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Used when we have some input context like gpt
    const int* total_padding_tokens = nullptr;

    const bool* masked_tokens = nullptr;
    const int* prefix_prompt_lengths = nullptr;
    int max_prefix_prompt_length = 0;

    const T* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;
    // The slope per head of linear position bias to attention score (H).
    const float* linear_bias_slopes = nullptr;

    const T* ia3_key_weights = nullptr;
    const T* ia3_value_weights = nullptr;
    const int* ia3_tasks = nullptr;

    const float* qkv_scale_out = nullptr;
    const float* attention_out_scale = nullptr;
    int int8_mode = 0;
};
template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // allows to exit attention early
    bool* finished = nullptr;

    // required in case of cross attention
    // will need it here till if constexpr in c++17
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // allows to exit attention early
    bool* finished = nullptr;

    // required in case of cross attention
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;

template<typename T>
struct outputCrossAttentionParam {
    // max decoder output length
    int max_decoder_seq_len = 0;
    T* cross_attention_out = nullptr;
    bool is_return_cross_attentions = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
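The B/L/D/H/Dh terminology documented above also fixes how much memory the k_cache and v_cache buffers need: at least B x L x D elements each. A back-of-the-envelope sketch, using assumed Llama-7B-like dimensions purely for illustration:

# Sketch: minimum size of one K or V cache buffer (B x L x D elements, D = H * Dh).
# The concrete numbers below are illustrative assumptions, not taken from the diff.
def kv_cache_bytes(batch_size, memory_max_len, num_heads, head_dim, dtype_bytes=2):
    return batch_size * memory_max_len * num_heads * head_dim * dtype_bytes

B, L, H, Dh = 1, 2048, 32, 128
one_buffer = kv_cache_bytes(B, L, H, Dh)
print(f"K cache: {one_buffer / 2**20:.0f} MiB, K+V: {2 * one_buffer / 2**20:.0f} MiB")  # 16 MiB, 32 MiB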
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp (new file, 0 → 100644)
(Diff collapsed in the page view; +1608 lines not shown.)
awq_cuda/attention/decoder_masked_multihead_attention_utils.h (new file, 0 → 100644)
(Diff collapsed in the page view; +1786 lines not shown.)
awq_cuda/attention/ft_attention.cpp (new file, 0 → 100644)
// Adapted from NVIDIA/FasterTransformer and FlashAttention
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>
#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
if (TYPE == at::ScalarType::Half) { \
using scalar_t = at::Half; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::BFloat16) { \
using scalar_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::Float) { \
using scalar_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
}
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params, const cudaStream_t& stream);

template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};

template<typename T>
void set_params(Masked_multihead_attention_params<T> &params,
                const size_t batch_size,
                const size_t nheads,
                const size_t nheads_kv,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const float rotary_base,
                const bool neox_rotary_style,
                const int qkv_batch_stride,
                T *q_ptr,
                T *k_ptr,
                T *v_ptr,
                T *k_cache_ptr,
                T *v_cache_ptr,
                int *length_per_sample,
                float *alibi_slopes_ptr,
                T *out_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.linear_bias_slopes = alibi_slopes_ptr;
    params.out = out_ptr;
    params.cache_indir = nullptr;
    params.stride = qkv_batch_stride;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.num_kv_heads = nheads_kv;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.rotary_base = rotary_base;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
}
torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim,
                                     const float rotary_base,
                                     // neox_rotary_style = not interleaved
                                     const bool neox_rotary_style) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = q.size(1);
    int nheads_kv = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
    // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        CHECK_SHAPE(alibi_slopes, nheads);
        CHECK_CONTIGUOUS(alibi_slopes);
        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    torch::Tensor out = torch::empty_like(q);
    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim,
                   timestep, rotary_embedding_dim, rotary_base, neox_rotary_style,
                   q.stride(0),
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value() ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   alibi_slopes_.has_value() ? alibi_slopes_.value().data_ptr<float>() : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}
\ No newline at end of file
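Note the asymmetric cache layouts checked above: v_cache is a plain [B, H_kv, L, Dh] tensor, while k_cache must be pre-packed as [B, H_kv, Dh/x, L, x] with x = 8 for fp16 and x = 4 for fp32. The sketch below shows one way a caller could rearrange an ordinary [B, H_kv, L, Dh] key cache into that packed shape before invoking the binding; it is an illustration under that shape contract, not code from this commit:

import torch

def pack_k_cache(k_cache: torch.Tensor) -> torch.Tensor:
    # k_cache: [B, H_kv, L, Dh] -> [B, H_kv, Dh/x, L, x], x = 8 (fp16) or 4 (fp32)
    x = 4 if k_cache.dtype == torch.float32 else 8
    B, H, L, Dh = k_cache.shape
    return (k_cache.view(B, H, L, Dh // x, x)   # split Dh into packs of x
                   .permute(0, 1, 3, 2, 4)      # move the pack index ahead of L
                   .contiguous())               # CHECK_CONTIGUOUS(k_cache) requires this

k = torch.zeros(1, 32, 2048, 128, dtype=torch.float16, device="cuda")
print(pack_k_cache(k).shape)                    # torch.Size([1, 32, 16, 2048, 8])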
awq_cuda/attention/ft_attention.h (new file, 0 → 100644)
#pragma once
#include <torch/extension.h>
torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const float rotary_base = 10000.0f,
                                     const bool neox_rotary_style = true);
\ No newline at end of file
awq_cuda/pybind.cpp (modified)
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
          py::arg("rotary_embedding_dim") = 0, py::arg("rotary_base") = 10000.0f,
          py::arg("neox_rotary_style") = true);
}
\ No newline at end of file
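Once the extension is compiled, the bindings registered here become plain Python functions. The import name below is an assumption (it comes from the extension name configured in setup.py, which is truncated in this diff), so treat it as a placeholder:

# Sketch: "awq_inference_engine" is an assumed module name; use whatever name
# the CUDAExtension in setup.py actually builds.
import awq_inference_engine as ie

print([name for name in dir(ie) if not name.startswith("_")])
# Expected to list: gemm_forward_cuda, gemv_forward_cuda, layernorm_forward_cuda,
# rotary_embedding_neox, single_query_attention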
awq_cuda/quantization/gemm_cuda_gen.cu (modified)
...
@@ -10,6 +10,7 @@
  */
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemm_cuda.h"
 #include "dequantize.cuh"
 #include <cuda_fp16.h>
...
@@ -439,6 +440,7 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
     int group_size = num_in_channels / _scaling_factors.size(0);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

     if (num_out_channels % 64 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 64");
...
@@ -456,7 +458,7 @@ torch::Tensor gemm_forward_cuda(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
...
@@ -467,7 +469,7 @@ torch::Tensor gemm_forward_cuda(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
...
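Both one-line changes above switch the GEMM kernel launches from the default CUDA stream to at::cuda::getCurrentCUDAStream(), so the custom kernel stays ordered with whatever stream PyTorch is currently using. A caller-side sketch of the situation this guards against, using only public PyTorch APIs:

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # Work queued here runs on stream `s`. A custom kernel hard-coded to the
    # default stream would not be ordered after these ops without extra syncs.
    a = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    c = a @ b
torch.cuda.current_stream().wait_stream(s)  # rejoin the default stream before using c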
awq_cuda/quantization/gemv_cuda.cu (new file, 0 → 100644)
// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32
// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
  #pragma unroll
  for(int i = 4; i >= 0; i--){
    sum += __shfl_down_sync(0xffffffff, sum, 1 << i);
  }
  /*
  // Equivalent to the following tree reduction implementation:
  sum += __shfl_down_sync(0xffffffff, sum, 16);
  sum += __shfl_down_sync(0xffffffff, sum, 8);
  sum += __shfl_down_sync(0xffffffff, sum, 4);
  sum += __shfl_down_sync(0xffffffff, sum, 2);
  sum += __shfl_down_sync(0xffffffff, sum, 1);
  */
  return sum;
}

__device__ __forceinline__ int make_divisible(int c, int divisor){
  return (c + divisor - 1) / divisor;
}
/*
Computes GEMV (group_size = 64).
Args:
inputs: vector of shape [batch_size, IC];
weight: matrix of shape [OC, IC / 8];
output: vector of shape [OC];
zeros: matrix of shape [OC, IC / group_size / 8];
scaling_factors: matrix of shape [OC, IC / group_size];
Notes:
One cannot infer group_size from the shape of scaling factors.
the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 64;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    // This is essentially zeros_w.
    const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    // consistent with input shape
    const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
    // tile size: 4 OC x 1024 IC per iter
    for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
      // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
      uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over 8 numbers packed within each uint32_t number
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}
/*
Computes GEMV (group_size = 128).
Args:
inputs: vector of shape [batch_size, IC];
weight: matrix of shape [OC, IC / 8];
output: vector of shape [OC];
zeros: matrix of shape [OC, IC / group_size / 8];
scaling_factors: matrix of shape [OC, IC / group_size];
Notes:
One cannot infer group_size from the shape of scaling factors.
the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g128(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 128;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
    // consistent with input shape
    const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
    //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
    // tile size: 4 OC x 1024 IC per iter
    for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
      // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
      uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over 8 numbers packed within each uint32_t number
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}
/*
Computes GEMV (PyTorch interface).
Args:
_in_feats: tensor of shape [B, IC];
_kernel: int tensor of shape [OC, IC // 8];
_zeros: int tensor of shape [OC, IC // G // 8];
_scaling_factors: tensor of shape [OC, IC // G];
blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;
Returns:
out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size)
{
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    // int kernel_volume = _out_in_map.size(1);
    auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
    auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    // auto out_in_map = _out_in_map.data_ptr<int>();
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    // kernel is [OC, IC]
    at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    int blockDim_z = num_out_feats;
    dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
    dim3 num_threads(32, 4);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (group_size == 64)
    {
      gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    else if (group_size == 128)
    {
      gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    return _out_feats;
}
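Both kernels recover each weight as scaling_factor * (nibble - zero), pulling 4-bit values out of packed 32-bit words low-nibble first and reading one zero-point nibble per group. The PyTorch reference below mirrors that arithmetic for CPU cross-checking; it is a sketch that follows the unpacking order of the kernels above, and the handling of the padded zero/scale columns is an assumption:

import torch

def dequantize_rows(qweight, qzeros, scales, group_size):
    """Sketch: expand packed int4 weights back to float.
    qweight: int32 [OC, IC//8], qzeros: int32 [OC, ~IC//group_size//8 (padded)],
    scales: fp16 [OC, ~IC//group_size (possibly padded)]."""
    OC, packed_ic = qweight.shape
    IC = packed_ic * 8
    shifts = torch.arange(0, 32, 4, device=qweight.device)          # low nibble first
    w = ((qweight.unsqueeze(-1) >> shifts) & 0xF).view(OC, IC).float()
    z = ((qzeros.unsqueeze(-1) >> shifts) & 0xF).view(OC, -1).float()
    z = z[:, : IC // group_size].repeat_interleave(group_size, dim=1)
    s = scales.float()[:, : IC // group_size].repeat_interleave(group_size, dim=1)
    return s * (w - z)                                               # [OC, IC]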
awq_cuda/quantization/gemv_cuda.h (new file, 0 → 100644)
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size);
examples/basic_quant.py (modified)
...
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer

 model_path = 'lmsys/vicuna-7b-v1.5'
 quant_path = 'vicuna-7b-v1.5-awq'
-quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 }
+quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

 # Load model
 model = AutoAWQForCausalLM.from_pretrained(model_path)
...
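For context, the one-line change above only adds a "version" key ("GEMM") to the quantization config; the rest of the example is elided by the diff. The sketch below fills in the usual AutoAWQ flow around it, with the quantize/save method names assumed from the library's public API rather than taken from this page:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model and tokenizer
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize with the config above, then persist the result (method names assumed)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)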
examples/benchmark.py (new file, 0 → 100644)
import time
import torch
import argparse
import numpy as np
import pandas as pd
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from torch.cuda import OutOfMemoryError

def warmup(model):
    warm_up = torch.randn((4096, 4096)).to(next(model.parameters()).device)
    torch.mm(warm_up, warm_up)

def generate(model, input_ids, n_generate):
    context_time = 0
    generate_time = []

    with torch.inference_mode():
        for i in range(n_generate):
            torch.cuda.synchronize()
            start = time.time()

            if i == 0:
                # prefill context
                inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
            else:
                # decode tokens
                inputs = torch.as_tensor(token, device=next(model.parameters()).device)

            out = model(inputs, use_cache=True)
            torch.cuda.synchronize()
            token = out[0][:, -1].max(1)[1].unsqueeze(1)

            if i == 0:
                context_time += time.time() - start
            else:
                generate_time.append(time.time() - start)

    return context_time, generate_time

def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
    print(f" -- Loading model...")
    model = AutoAWQForCausalLM.from_quantized(
        model_path, quant_file, fuse_layers=True,
        max_new_tokens=n_generate, batch_size=batch_size
    )

    print(f" -- Warming up...")
    warmup(model)

    print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")

    try:
        context_time, generate_time = generate(model, input_ids, n_generate)
        successful_generate = True
    except RuntimeError as ex:
        if 'cuda out of memory' in str(ex).lower():
            successful_generate = False
        else:
            raise RuntimeError(ex)

    device = next(model.parameters()).device
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100

    if successful_generate:
        # number of tokens in context / time for processing context * batch size
        prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
        # 1 second / median time per token in seconds * batch size
        decode_tokens_per_second = 1 / np.median(generate_time) * batch_size

        print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
        print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
        print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
    else:
        prefill_tokens_per_second = 'OOM'
        decode_tokens_per_second = 'OOM'

    return {
        "Batch Size": batch_size,
        "Prefill Length": input_ids.shape[1],
        "Decode Length": n_generate,
        "Prefill tokens/s": prefill_tokens_per_second,
        "Decode tokens/s": decode_tokens_per_second,
        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
    }, model.quant_config["version"]

def main(args):
    rounds = [
        {"context": 32, "n_generate": 32},
        {"context": 64, "n_generate": 64},
        {"context": 128, "n_generate": 128},
        {"context": 256, "n_generate": 256},
        {"context": 512, "n_generate": 512},
        {"context": 1024, "n_generate": 1024},
        {"context": 2048, "n_generate": 2048},
    ]

    all_stats = []
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    for settings in rounds:
        input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()

        stats, model_version = run_round(
            args.model_path,
            args.quant_file,
            settings["n_generate"],
            input_ids,
            args.batch_size
        )

        all_stats.append(stats)

        if stats["Prefill tokens/s"] == 'OOM':
            break

    df = pd.DataFrame(all_stats)
    print('GPU:', torch.cuda.get_device_name())
    print('Model:', args.model_path)
    print('Version:', model_version)
    print(df.to_markdown(index=False))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
    parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    args = parser.parse_args()

    main(args)
\ No newline at end of file
setup.py (modified)
...
@@ -9,7 +9,7 @@ os.environ["CC"] = "g++"
 os.environ["CXX"] = "g++"

 common_setup_kwargs = {
-    "version": "0.0.1",
+    "version": "0.0.2",
     "name": "autoawq",
     "author": "Casper Hansen",
     "license": "MIT",
...
@@ -44,14 +44,28 @@ requirements = [
     "toml",
     "attributedict",
     "protobuf",
-    "torchvision"
+    "torchvision",
+    "tabulate"
 ]

-include_dirs = []
-
-conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
-if os.path.isdir(conda_cuda_include_dir):
-    include_dirs.append(conda_cuda_include_dir)
+def get_include_dirs():
+    include_dirs = []
+
+    conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
+    if os.path.isdir(conda_cuda_include_dir):
+        include_dirs.append(conda_cuda_include_dir)
+
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    include_dirs.append(this_dir)
+
+    return include_dirs
+
+def get_generator_flag():
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    return generator_flag

 def check_dependencies():
     if CUDA_HOME is None:
...
@@ -77,6 +91,8 @@ def get_compute_capabilities():
     return capability_flags

 check_dependencies()
+include_dirs = get_include_dirs()
+generator_flags = get_generator_flag()
 arch_flags = get_compute_capabilities()

 if os.name == "nt":
...
@@ -86,8 +102,21 @@ if os.name == "nt":
     }
 else:
     extra_compile_args={
-        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
-        "nvcc": ["-O3", "-std=c++17"] + arch_flags
+        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
+        "nvcc": [
+            "-O3",
+            "-std=c++17",
+            "-DENABLE_BF16",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+            "--expt-relaxed-constexpr",
+            "--expt-extended-lambda",
+            "--use_fast_math",
+        ] + arch_flags + generator_flags
     }

 extensions = [
...
@@ -97,7 +126,10 @@ extensions = [
         "awq_cuda/pybind.cpp",
         "awq_cuda/quantization/gemm_cuda_gen.cu",
         "awq_cuda/layernorm/layernorm.cu",
-        "awq_cuda/position_embedding/pos_encoding_kernels.cu"
+        "awq_cuda/position_embedding/pos_encoding_kernels.cu",
+        "awq_cuda/quantization/gemv_cuda.cu",
+        "awq_cuda/attention/ft_attention.cpp",
+        "awq_cuda/attention/decoder_masked_multihead_attention.cu"
     ], extra_compile_args=extra_compile_args
     )
 ]
...
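The new compile flags build the extension with -DENABLE_BF16, which enables the __nv_bfloat16 code paths guarded by #ifdef ENABLE_BF16 in the attention sources above. Whether bf16 is actually usable still depends on the GPU at runtime; a quick check, given as a sketch rather than part of the commit:

import torch

# bf16 kernels are only useful on GPUs that support bfloat16
# (compute capability 8.0 / Ampere and newer).
print("CUDA available:    ", torch.cuda.is_available())
print("bf16 supported:    ", torch.cuda.is_bf16_supported())
print("compute capability:", torch.cuda.get_device_capability())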