Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c2170a5b
Unverified
Commit
c2170a5b
authored
Nov 18, 2024
by
Angus Wang
Committed by
GitHub
Nov 18, 2024
Browse files
[Kernel] Explicitly specify other value in tl.load calls (#9014)
Signed-off-by:
Angus Wang
<
wangjadehao@gmail.com
>
parent
6b2d25ef
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
35 additions
and
14 deletions
+35
-14
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
...ops/blocksparse_attention/blocksparse_attention_kernel.py
+10
-3
vllm/lora/ops/bgmv_expand.py
vllm/lora/ops/bgmv_expand.py
+3
-1
vllm/lora/ops/bgmv_expand_slice.py
vllm/lora/ops/bgmv_expand_slice.py
+7
-1
vllm/lora/ops/sgmv_expand.py
vllm/lora/ops/sgmv_expand.py
+4
-1
vllm/lora/ops/sgmv_expand_slice.py
vllm/lora/ops/sgmv_expand_slice.py
+4
-1
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+7
-7
No files found.
vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
View file @
c2170a5b
...
...
@@ -157,19 +157,22 @@ def _fwd_kernel_inner(
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
,
mask
=
offs_n
[
None
,
:]
+
start_n
<
k_seqlen
,
other
=
0.0
,
)
else
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
,
mask
=
(
offs_n
[
None
,
:]
+
start_n
<
k_seqlen
)
&
(
offs_d
[:,
None
]
<
D_HEAD
),
other
=
0.0
,
)
else
:
if
EVEN_D
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
)
else
:
k
=
tl
.
load
(
k_ptrs
+
start_n
*
stride_kt
,
mask
=
offs_d
[:,
None
]
<
D_HEAD
)
mask
=
offs_d
[:,
None
]
<
D_HEAD
,
other
=
0.0
)
qk
=
tl
.
zeros
([
BLOCK_M_LOADING
,
BLOCK_N
],
dtype
=
tl
.
float32
)
qk
+=
tl
.
dot
(
q
,
k
)
...
...
@@ -200,19 +203,22 @@ def _fwd_kernel_inner(
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
,
mask
=
offs_n
[:,
None
]
+
start_n
<
k_seqlen
,
other
=
0.0
,
)
else
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
,
mask
=
(
offs_n
[:,
None
]
+
start_n
<
k_seqlen
)
&
(
offs_d
[
None
,
:]
<
D_HEAD
),
other
=
0.0
,
)
else
:
if
EVEN_D
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
)
else
:
v
=
tl
.
load
(
v_ptrs
+
start_n
*
stride_vt
,
mask
=
offs_d
[
None
,
:]
<
D_HEAD
)
mask
=
offs_d
[
None
,
:]
<
D_HEAD
,
other
=
0.0
)
acc
+=
tl
.
dot
(
p
,
v
)
...
...
@@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference(
q
=
tl
.
load
(
Q
+
offs_m
[:,
None
]
*
stride_qt
+
offs_d
[
None
,
:]
*
stride_qd
,
mask
=
offs_m
[:,
None
]
<
q_seqlen
,
other
=
0.0
,
)
else
:
q
=
tl
.
load
(
Q
+
offs_m
[:,
None
]
*
stride_qt
+
offs_d
[
None
,
:]
*
stride_qd
,
mask
=
(
offs_m
[:,
None
]
<
q_seqlen
)
&
(
offs_d
[
None
,
:]
<
D_HEAD
),
other
=
0
,
other
=
0
.0
,
)
sparse_crow_ptr
=
(
layout_crow_ptr
+
off_h
*
layout_crow_stride_h
+
...
...
vllm/lora/ops/bgmv_expand.py
View file @
c2170a5b
...
...
@@ -75,7 +75,9 @@ def _bgmv_expand_kernel(
other
=
0.0
,
)
# [BLOCK_N,BLOCK_K]
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
+
current_n
*
cn_stride
,
mask
=
c_mask
)
tiled_out
=
tl
.
load
(
c_ptr
+
current_n
*
cn_stride
,
mask
=
c_mask
,
other
=
0.0
)
accumulator
=
tl
.
sum
(
tiled_a
*
tiled_b
,
1
)
+
tiled_out
else
:
accumulator
=
tl
.
sum
(
tiled_a
*
tiled_b
,
1
)
...
...
vllm/lora/ops/bgmv_expand_slice.py
View file @
c2170a5b
...
...
@@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel(
)
# [BLOCK_N,BLOCK_K]
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
+
current_n
*
cn_stride
,
mask
=
c_mask
)
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store
# operation uses the same mask, eliminating the risk of garbage
# values propagating
tiled_out
=
tl
.
load
(
c_ptr
+
current_n
*
cn_stride
,
mask
=
c_mask
,
other
=
None
)
accumulator
=
tl
.
sum
(
tiled_a
*
tiled_b
,
1
)
+
tiled_out
else
:
accumulator
=
tl
.
sum
(
tiled_a
*
tiled_b
,
1
)
...
...
vllm/lora/ops/sgmv_expand.py
View file @
c2170a5b
...
...
@@ -88,7 +88,10 @@ def _sgmv_expand_kernel(
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
N
)
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store operation
# uses the same mask, eliminating the risk of garbage values propagating
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
,
other
=
None
)
tiled_c
+=
tiled_out
tl
.
store
(
c_ptr
,
tiled_c
,
mask
=
c_mask
)
...
...
vllm/lora/ops/sgmv_expand_slice.py
View file @
c2170a5b
...
...
@@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel(
c_mask
=
(
offset_cm
[:,
None
]
<
(
cur_seq_start
+
M
))
&
(
offset_cn
[
None
,
:]
<
(
slice_offset
+
N
))
if
ADD_INPUTS
:
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
)
# explicitly pass in other=None to tell triton that masked values
# can be uninitialized. This is OK because the later tl.store operation
# uses the same mask, eliminating the risk of garbage values propagating
tiled_out
=
tl
.
load
(
c_ptr
,
mask
=
c_mask
,
other
=
None
)
tiled_c
+=
tiled_out
tl
.
store
(
c_ptr
,
tiled_c
,
mask
=
c_mask
)
...
...
vllm/model_executor/layers/quantization/awq_triton.py
View file @
c2170a5b
...
...
@@ -42,7 +42,7 @@ def awq_dequantize_kernel(
result_masks
=
result_masks_y
[:,
None
]
&
result_masks_x
[
None
,
:]
# Load the weights.
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
)
iweights
=
tl
.
load
(
qweight_ptr
+
offsets
,
masks
,
0.0
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
iweights
=
tl
.
interleave
(
iweights
,
iweights
)
...
...
@@ -71,7 +71,7 @@ def awq_dequantize_kernel(
zero_masks
=
zero_masks_y
[:,
None
]
&
zero_masks_x
[
None
,
:]
# Load the zeros.
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
)
zeros
=
tl
.
load
(
zeros_ptr
+
zero_offsets
,
zero_masks
,
0.0
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
...
...
@@ -91,7 +91,7 @@ def awq_dequantize_kernel(
scale_masks
=
scale_masks_y
[:,
None
]
&
scale_masks_x
[
None
,
:]
# Load the scales.
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
)
scales
=
tl
.
load
(
scales_ptr
+
scale_offsets
,
scale_masks
,
0.0
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_Y
,
BLOCK_SIZE_X
*
8
))
# Dequantize.
...
...
@@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
*
SPLIT_K
)):
masks_k
=
offsets_k
<
K
masks_a
=
masks_am
[:,
None
]
&
masks_k
[
None
,
:]
a
=
tl
.
load
(
a_ptrs
,
mask
=
masks_a
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
masks_a
,
other
=
0.0
)
masks_b
=
masks_k
[:,
None
]
&
masks_bn
[
None
,
:]
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
masks_b
,
other
=
0.0
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
b
=
tl
.
interleave
(
b
,
b
)
...
...
@@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_zk
=
offsets_szk
<
K
//
group_size
masks_z
=
masks_zk
[:,
None
]
&
masks_zn
[
None
,
:]
zeros_ptrs
=
zeros_ptr
+
offsets_z
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
)
zeros
=
tl
.
load
(
zeros_ptrs
,
mask
=
masks_z
,
other
=
0.0
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
zeros
=
tl
.
interleave
(
zeros
,
zeros
)
...
...
@@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
masks_sk
=
offsets_szk
<
K
//
group_size
masks_s
=
masks_sk
[:,
None
]
&
masks_sn
[
None
,
:]
scales_ptrs
=
scales_ptr
+
offsets_s
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
)
scales
=
tl
.
load
(
scales_ptrs
,
mask
=
masks_s
,
other
=
0.0
)
scales
=
tl
.
broadcast_to
(
scales
,
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
))
b
=
(
b
>>
shifts
)
&
0xF
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment