Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9fce7bee
Unverified
Commit
9fce7bee
authored
Oct 20, 2025
by
Jiangyun Zhu
Committed by
GitHub
Oct 20, 2025
Browse files
[Kernel] Accelerate solve_tril with TMA (#26746)
Signed-off-by:
zjy0516
<
riverclouds.zhu@qq.com
>
parent
b63f2143
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
412 additions
and
301 deletions
+412
-301
vllm/model_executor/layers/fla/ops/op.py
vllm/model_executor/layers/fla/ops/op.py
+32
-11
vllm/model_executor/layers/fla/ops/solve_tril.py
vllm/model_executor/layers/fla/ops/solve_tril.py
+375
-290
vllm/model_executor/layers/fla/ops/utils.py
vllm/model_executor/layers/fla/ops/utils.py
+5
-0
No files found.
vllm/model_executor/layers/fla/ops/op.py
View file @
9fce7bee
...
@@ -11,29 +11,50 @@ import os
...
@@ -11,29 +11,50 @@ import os
from
vllm.triton_utils
import
tl
,
tldevice
,
triton
from
vllm.triton_utils
import
tl
,
tldevice
,
triton
from
.utils
import
is_gather_supported
if
os
.
environ
.
get
(
"FLA_USE_FAST_OPS"
,
"0"
)
==
"1"
:
if
os
.
environ
.
get
(
"FLA_USE_FAST_OPS"
,
"0"
)
==
"1"
:
div
=
tldevice
.
fast_dividef
exp
=
tldevice
.
fast_expf
exp
=
tldevice
.
fast_expf
log
=
tldevice
.
fast_logf
log
=
tldevice
.
fast_logf
log2
=
tldevice
.
fast_log2f
log2
=
tldevice
.
fast_log2f
else
:
else
:
@
triton
.
jit
def
div_normal
(
x
,
y
):
return
x
/
y
div
=
div_normal
exp
=
tl
.
exp
exp
=
tl
.
exp
log
=
tl
.
log
log
=
tl
.
log
log2
=
tl
.
log2
log2
=
tl
.
log2
if
not
hasattr
(
tl
,
"gather"
)
:
if
not
is_gather_supported
:
@
triton
.
jit
@
triton
.
jit
def
gather
(
src
,
index
,
axis
,
_builder
=
None
):
def
gather
(
src
,
index
,
axis
,
_builder
=
None
):
# This is a fallback implementation when tl.gather is not supported
"""
# In order to pass triton compiler, there is no actual gather operation
Gather operation that works when tl.gather is not supported.
return
src
This is a fallback implementation that returns None.
Just to make triton compiler happy.
"""
return
None
else
:
else
:
gather
=
tl
.
gather
gather
=
tl
.
gather
if
hasattr
(
triton
.
language
,
"_experimental_make_tensor_descriptor"
):
# For Triton 3.3.x
make_tensor_descriptor
=
triton
.
language
.
_experimental_make_tensor_descriptor
elif
hasattr
(
triton
.
language
,
"make_tensor_descriptor"
):
# For Triton 3.4.x and later
make_tensor_descriptor
=
triton
.
language
.
make_tensor_descriptor
else
:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
Just make triton compiler happy.
"""
@
triton
.
jit
def
make_tensor_descriptor
(
base
,
shape
,
strides
,
block_shape
,
_builder
=
None
,
):
return
None
vllm/model_executor/layers/fla/ops/solve_tril.py
View file @
9fce7bee
...
@@ -8,12 +8,21 @@
...
@@ -8,12 +8,21 @@
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# ruff: noqa: E501
import
os
import
torch
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.index
import
prepare_chunk_indices
from
.utils
import
input_guard
from
.op
import
make_tensor_descriptor
from
.utils
import
input_guard
,
is_amd
,
is_tma_supported
FLA_TRIL_PRECISION
=
os
.
environ
.
get
(
"FLA_TRIL_PRECISION"
,
"ieee"
)
ALLOWED_TRIL_PRECISIONS
=
[
"ieee"
,
"tf32"
]
if
is_amd
else
[
"ieee"
,
"tf32"
,
"tf32x3"
]
assert
FLA_TRIL_PRECISION
in
ALLOWED_TRIL_PRECISIONS
,
(
f
"FLA_TRIL_PRECISION must be one of
{
ALLOWED_TRIL_PRECISIONS
}
, but got
{
FLA_TRIL_PRECISION
}
"
)
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
...
@@ -28,13 +37,15 @@ from .utils import input_guard
...
@@ -28,13 +37,15 @@ from .utils import input_guard
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
solve_tril_16x16_kernel
(
def
solve_tril_16x16_kernel
(
A
,
A
,
A
d
,
A
i
,
cu_seqlens
,
cu_seqlens
,
chunk_indices
,
chunk_indices
,
T
,
T
,
H
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
USE_TMA
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
DOT_PRECISION
:
tl
.
constexpr
,
):
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
...
@@ -50,30 +61,43 @@ def solve_tril_16x16_kernel(
...
@@ -50,30 +61,43 @@ def solve_tril_16x16_kernel(
T
=
eos
-
bos
T
=
eos
-
bos
else
:
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
o_i
=
tl
.
arange
(
0
,
16
)
m_A
=
o_i
[:,
None
]
>
o_i
[
None
,
:]
m_I
=
o_i
[:,
None
]
==
o_i
[
None
,
:]
A
=
A
+
(
bos
*
H
+
i_h
)
*
BT
A
=
A
+
(
bos
*
H
+
i_h
)
*
BT
A
d
=
A
d
+
(
bos
*
H
+
i_h
)
*
16
A
i
=
A
i
+
(
bos
*
H
+
i_h
)
*
16
offset
=
(
i_t
*
16
)
%
BT
offset
=
(
i_t
*
16
)
%
BT
p_A
=
tl
.
make_block_ptr
(
if
not
USE_TMA
:
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
16
,
offset
),
(
16
,
16
),
(
1
,
0
)
p_A
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
16
,
offset
),
(
16
,
16
),
(
1
,
0
)
p_Ai
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
16
,
0
),
(
16
,
16
),
(
1
,
0
))
)
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
# [16, 16]
b_A
=
-
tl
.
where
(
tl
.
arange
(
0
,
16
)[:,
None
]
>
tl
.
arange
(
0
,
16
)[
None
,
:],
b_A
,
0
)
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
else
:
desc
=
make_tensor_descriptor
(
A
,
[
T
,
BT
],
[
H
*
BT
,
1
],
[
16
,
16
])
desc_o
=
make_tensor_descriptor
(
Ai
,
[
T
,
16
],
[
H
*
16
,
1
],
[
16
,
16
])
b_A
=
desc
.
load
([
i_t
*
16
,
offset
]).
to
(
tl
.
float32
)
b_A
=
-
tl
.
where
(
m_A
,
b_A
,
0
)
o_i
=
tl
.
arange
(
0
,
16
)
for
i
in
range
(
2
,
min
(
16
,
T
-
i_t
*
16
)
):
for
i
in
range
(
1
,
min
(
16
,
T
-
i_t
*
16
)):
# [16]
b_a
=
-
tl
.
load
(
A
+
(
i_t
*
16
+
i
)
*
H
*
BT
+
o_i
+
offset
)
b_a
=
-
tl
.
load
(
A
+
(
i_t
*
16
+
i
)
*
H
*
BT
+
o_i
+
offset
)
b_a
=
b_a
+
tl
.
sum
(
b_a
[:,
None
]
*
b_A
,
0
)
b_a
=
b_a
+
tl
.
sum
(
b_a
[:,
None
]
*
b_A
,
0
)
mask
=
o_i
==
i
b_A
=
tl
.
where
((
o_i
==
i
)[:,
None
],
b_a
,
b_A
)
b_A
=
tl
.
where
(
mask
[:,
None
],
b_a
,
b_A
)
b_A
+=
m_I
b_A
+=
o_i
[:,
None
]
==
o_i
[
None
,
:]
if
not
USE_TMA
:
tl
.
store
(
p_Ai
=
tl
.
make_block_ptr
(
p_Ai
,
Ai
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
16
,
0
),
(
16
,
16
),
(
1
,
0
)
b_A
.
to
(
p_Ai
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
)
boundary_check
=
(
0
,
1
),
tl
.
store
(
)
p_Ai
,
b_A
.
to
(
p_Ai
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
else
:
desc_o
.
store
([
i_t
*
16
,
0
],
b_A
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
))
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
...
@@ -88,14 +112,15 @@ def solve_tril_16x16_kernel(
...
@@ -88,14 +112,15 @@ def solve_tril_16x16_kernel(
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
merge_16x16_to_32x32_inverse_kernel
(
def
merge_16x16_to_32x32_inverse_kernel
(
A
,
A
,
Ad
,
Ai
,
Ai
,
cu_seqlens
,
cu_seqlens
,
chunk_indices
,
chunk_indices
,
T
,
T
,
H
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
USE_TMA
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
DOT_PRECISION
:
tl
.
constexpr
,
):
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
...
@@ -112,51 +137,93 @@ def merge_16x16_to_32x32_inverse_kernel(
...
@@ -112,51 +137,93 @@ def merge_16x16_to_32x32_inverse_kernel(
else
:
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
+=
(
bos
*
H
+
i_h
)
*
32
o_i
=
tl
.
arange
(
0
,
16
)
Ad
+=
(
bos
*
H
+
i_h
)
*
16
m_A
=
o_i
[:,
None
]
>
o_i
[
None
,
:]
Ai
+=
(
bos
*
H
+
i_h
)
*
32
m_I
=
o_i
[:,
None
]
==
o_i
[
None
,
:]
A
+=
(
bos
*
H
+
i_h
)
*
BT
Ai
+=
(
bos
*
H
+
i_h
)
*
BT
p_A_21
=
tl
.
make_block_ptr
(
if
not
USE_TMA
:
A
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
p_A_11
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
16
,
16
),
(
1
,
0
)
p_Ad_11
=
tl
.
make_block_ptr
(
)
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
32
,
0
),
(
16
,
16
),
(
1
,
0
)
p_A_22
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
16
),
(
16
,
16
),
(
1
,
0
)
p_Ad_22
=
tl
.
make_block_ptr
(
)
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
b_Ai_11
=
tl
.
load
(
p_A_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
)
b_Ai_22
=
tl
.
load
(
p_A_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
p_Ai_11
=
tl
.
make_block_ptr
(
else
:
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
,
0
),
(
16
,
16
),
(
1
,
0
)
desc
=
make_tensor_descriptor
(
A
,
[
T
,
BT
],
[
H
*
BT
,
1
],
[
16
,
16
])
)
desc_o
=
make_tensor_descriptor
(
Ai
,
[
T
,
BT
],
[
H
*
BT
,
1
],
[
16
,
16
])
p_Ai_22
=
tl
.
make_block_ptr
(
b_Ai_11
=
desc
.
load
([
i_t
*
BT
+
0
,
0
]).
to
(
tl
.
float32
)
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
16
),
(
16
,
16
),
(
1
,
0
)
b_Ai_22
=
desc
.
load
([
i_t
*
BT
+
16
,
16
]).
to
(
tl
.
float32
)
)
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
)
A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
# [16, 16]
Ai_11
=
tl
.
load
(
p_Ad_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_11
=
-
tl
.
where
(
m_A
,
b_Ai_11
,
0
)
Ai_22
=
tl
.
load
(
p_Ad_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_22
=
-
tl
.
where
(
m_A
,
b_Ai_22
,
0
)
Ai_21
=
-
tl
.
dot
(
tl
.
dot
(
Ai_22
,
A_21
,
input_precision
=
"ieee"
),
Ai_11
,
input_precision
=
"ieee"
for
i
in
range
(
2
,
min
(
16
,
T
-
i_t
*
BT
)):
)
b_a_11
=
-
tl
.
load
(
A
+
(
i_t
*
BT
+
i
)
*
H
*
BT
+
o_i
)
tl
.
store
(
b_a_11
+=
tl
.
sum
(
b_a_11
[:,
None
]
*
b_Ai_11
,
0
)
p_Ai_11
,
b_Ai_11
=
tl
.
where
((
o_i
==
i
)[:,
None
],
b_a_11
,
b_Ai_11
)
Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
for
i
in
range
(
16
+
2
,
min
(
32
,
T
-
i_t
*
BT
)):
boundary_check
=
(
0
,
1
),
b_a_22
=
-
tl
.
load
(
A
+
(
i_t
*
BT
+
i
)
*
H
*
BT
+
o_i
+
16
)
)
b_a_22
+=
tl
.
sum
(
b_a_22
[:,
None
]
*
b_Ai_22
,
0
)
tl
.
store
(
b_Ai_22
=
tl
.
where
((
o_i
==
i
-
16
)[:,
None
],
b_a_22
,
b_Ai_22
)
p_Ai_22
,
Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
b_Ai_11
+=
m_I
boundary_check
=
(
0
,
1
),
b_Ai_22
+=
m_I
)
tl
.
store
(
if
not
USE_TMA
:
p_Ai_21
,
p_A_21
=
tl
.
make_block_ptr
(
Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
boundary_check
=
(
0
,
1
),
)
b_A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
else
:
b_A_21
=
desc
.
load
([
i_t
*
BT
+
16
,
0
]).
to
(
tl
.
float32
)
b_Ai_21
=
-
tl
.
dot
(
tl
.
dot
(
b_Ai_22
,
b_A_21
,
input_precision
=
DOT_PRECISION
),
b_Ai_11
,
input_precision
=
DOT_PRECISION
,
)
)
if
not
USE_TMA
:
p_Ai_11
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
16
,
16
),
(
1
,
0
)
)
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
)
p_Ai_22
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
16
),
(
16
,
16
),
(
1
,
0
)
)
tl
.
store
(
p_Ai_11
,
b_Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_22
,
b_Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_21
,
b_Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
else
:
desc_o
.
store
(
[
i_t
*
BT
+
0
,
0
],
b_Ai_11
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
16
,
0
],
b_Ai_21
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
16
,
16
],
b_Ai_22
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
autotune
(
@
triton
.
autotune
(
...
@@ -170,14 +237,15 @@ def merge_16x16_to_32x32_inverse_kernel(
...
@@ -170,14 +237,15 @@ def merge_16x16_to_32x32_inverse_kernel(
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
merge_16x16_to_64x64_inverse_kernel
(
def
merge_16x16_to_64x64_inverse_kernel
(
A
,
A
,
Ad
,
Ai
,
Ai
,
cu_seqlens
,
cu_seqlens
,
chunk_indices
,
chunk_indices
,
T
,
T
,
H
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
USE_TMA
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
DOT_PRECISION
:
tl
.
constexpr
,
):
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
...
@@ -194,213 +262,245 @@ def merge_16x16_to_64x64_inverse_kernel(
...
@@ -194,213 +262,245 @@ def merge_16x16_to_64x64_inverse_kernel(
else
:
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
+=
(
bos
*
H
+
i_h
)
*
64
o_i
=
tl
.
arange
(
0
,
16
)
Ad
+=
(
bos
*
H
+
i_h
)
*
16
m_A
=
o_i
[:,
None
]
>
o_i
[
None
,
:]
Ai
+=
(
bos
*
H
+
i_h
)
*
64
m_I
=
o_i
[:,
None
]
==
o_i
[
None
,
:]
A
+=
(
bos
*
H
+
i_h
)
*
BT
Ai
+=
(
bos
*
H
+
i_h
)
*
BT
p_A_21
=
tl
.
make_block_ptr
(
if
not
USE_TMA
:
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
p_A_11
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
16
,
16
),
(
1
,
0
)
p_A_32
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
16
),
(
16
,
16
),
(
1
,
0
)
p_A_22
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
16
),
(
16
,
16
),
(
1
,
0
)
p_A_31
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
)
p_A_33
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
32
,
32
),
(
16
,
16
),
(
1
,
0
)
p_A_43
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
32
),
(
16
,
16
),
(
1
,
0
)
p_A_44
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
48
),
(
16
,
16
),
(
1
,
0
)
p_A_42
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
16
),
(
16
,
16
),
(
1
,
0
)
b_Ai_11
=
tl
.
load
(
p_A_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
)
b_Ai_22
=
tl
.
load
(
p_A_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
p_A_41
=
tl
.
make_block_ptr
(
b_Ai_33
=
tl
.
load
(
p_A_33
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
)
b_Ai_44
=
tl
.
load
(
p_A_44
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
)
else
:
p_Ad_11
=
tl
.
make_block_ptr
(
desc
=
make_tensor_descriptor
(
A
,
[
T
,
BT
],
[
H
*
BT
,
1
],
[
16
,
16
])
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
,
0
),
(
16
,
16
),
(
1
,
0
)
desc_o
=
make_tensor_descriptor
(
Ai
,
[
T
,
BT
],
[
H
*
BT
,
1
],
[
16
,
16
])
)
b_Ai_11
=
desc
.
load
([
i_t
*
BT
+
0
,
0
]).
to
(
tl
.
float32
)
p_Ad_22
=
tl
.
make_block_ptr
(
b_Ai_22
=
desc
.
load
([
i_t
*
BT
+
16
,
16
]).
to
(
tl
.
float32
)
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
b_Ai_33
=
desc
.
load
([
i_t
*
BT
+
32
,
32
]).
to
(
tl
.
float32
)
)
b_Ai_44
=
desc
.
load
([
i_t
*
BT
+
48
,
48
]).
to
(
tl
.
float32
)
p_Ad_33
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
)
)
p_Ad_44
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
)
)
A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
# [16, 16]
A_32
=
tl
.
load
(
p_A_32
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_11
=
-
tl
.
where
(
m_A
,
b_Ai_11
,
0
)
A_31
=
tl
.
load
(
p_A_31
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_22
=
-
tl
.
where
(
m_A
,
b_Ai_22
,
0
)
A_43
=
tl
.
load
(
p_A_43
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_33
=
-
tl
.
where
(
m_A
,
b_Ai_33
,
0
)
A_42
=
tl
.
load
(
p_A_42
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_44
=
-
tl
.
where
(
m_A
,
b_Ai_44
,
0
)
A_41
=
tl
.
load
(
p_A_41
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_11
=
tl
.
load
(
p_Ad_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
for
i
in
range
(
2
,
min
(
16
,
T
-
i_t
*
BT
)):
Ai_22
=
tl
.
load
(
p_Ad_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_a_11
=
-
tl
.
load
(
A
+
(
i_t
*
BT
+
i
)
*
H
*
BT
+
o_i
)
Ai_33
=
tl
.
load
(
p_Ad_33
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_a_11
+=
tl
.
sum
(
b_a_11
[:,
None
]
*
b_Ai_11
,
0
)
Ai_44
=
tl
.
load
(
p_Ad_44
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_Ai_11
=
tl
.
where
((
o_i
==
i
)[:,
None
],
b_a_11
,
b_Ai_11
)
for
i
in
range
(
16
+
2
,
min
(
32
,
T
-
i_t
*
BT
)):
b_a_22
=
-
tl
.
load
(
A
+
(
i_t
*
BT
+
i
)
*
H
*
BT
+
o_i
+
16
)
b_a_22
+=
tl
.
sum
(
b_a_22
[:,
None
]
*
b_Ai_22
,
0
)
b_Ai_22
=
tl
.
where
((
o_i
==
i
-
16
)[:,
None
],
b_a_22
,
b_Ai_22
)
for
i
in
range
(
32
+
2
,
min
(
48
,
T
-
i_t
*
BT
)):
b_a_33
=
-
tl
.
load
(
A
+
(
i_t
*
BT
+
i
)
*
H
*
BT
+
o_i
+
32
)
b_a_33
+=
tl
.
sum
(
b_a_33
[:,
None
]
*
b_Ai_33
,
0
)
b_Ai_33
=
tl
.
where
((
o_i
==
i
-
32
)[:,
None
],
b_a_33
,
b_Ai_33
)
for
i
in
range
(
48
+
2
,
min
(
64
,
T
-
i_t
*
BT
)):
b_a_44
=
-
tl
.
load
(
A
+
(
i_t
*
BT
+
i
)
*
H
*
BT
+
o_i
+
48
)
b_a_44
+=
tl
.
sum
(
b_a_44
[:,
None
]
*
b_Ai_44
,
0
)
b_Ai_44
=
tl
.
where
((
o_i
==
i
-
48
)[:,
None
],
b_a_44
,
b_Ai_44
)
b_Ai_11
+=
m_I
b_Ai_22
+=
m_I
b_Ai_33
+=
m_I
b_Ai_44
+=
m_I
Ai_21
=
-
tl
.
dot
(
if
not
USE_TMA
:
tl
.
dot
(
Ai_22
,
A_21
,
input_precision
=
"ieee"
),
Ai_11
,
input_precision
=
"ieee"
p_A_21
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
Ai_32
=
-
tl
.
dot
(
)
tl
.
dot
(
Ai_33
,
A_32
,
input_precision
=
"ieee"
),
Ai_22
,
input_precision
=
"ieee"
p_A_31
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
32
,
0
),
(
16
,
16
),
(
1
,
0
)
Ai_43
=
-
tl
.
dot
(
)
tl
.
dot
(
Ai_44
,
A_43
,
input_precision
=
"ieee"
),
Ai_33
,
input_precision
=
"ieee"
p_A_32
=
tl
.
make_block_ptr
(
)
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
32
,
16
),
(
16
,
16
),
(
1
,
0
)
)
p_A_41
=
tl
.
make_block_ptr
(
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
0
),
(
16
,
16
),
(
1
,
0
)
)
p_A_42
=
tl
.
make_block_ptr
(
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
16
),
(
16
,
16
),
(
1
,
0
)
)
p_A_43
=
tl
.
make_block_ptr
(
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
32
),
(
16
,
16
),
(
1
,
0
)
)
b_A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A_31
=
tl
.
load
(
p_A_31
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A_32
=
tl
.
load
(
p_A_32
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A_41
=
tl
.
load
(
p_A_41
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A_42
=
tl
.
load
(
p_A_42
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A_43
=
tl
.
load
(
p_A_43
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
else
:
b_A_21
=
desc
.
load
([
i_t
*
BT
+
16
,
0
]).
to
(
tl
.
float32
)
b_A_31
=
desc
.
load
([
i_t
*
BT
+
32
,
0
]).
to
(
tl
.
float32
)
b_A_32
=
desc
.
load
([
i_t
*
BT
+
32
,
16
]).
to
(
tl
.
float32
)
b_A_41
=
desc
.
load
([
i_t
*
BT
+
48
,
0
]).
to
(
tl
.
float32
)
b_A_42
=
desc
.
load
([
i_t
*
BT
+
48
,
16
]).
to
(
tl
.
float32
)
b_A_43
=
desc
.
load
([
i_t
*
BT
+
48
,
32
]).
to
(
tl
.
float32
)
Ai_31
=
-
tl
.
dot
(
b_Ai_21
=
-
tl
.
dot
(
Ai_33
,
tl
.
dot
(
b_Ai_22
,
b_A_21
,
input_precision
=
DOT_PRECISION
),
tl
.
dot
(
A_31
,
Ai_11
,
input_precision
=
"ieee"
)
b_Ai_11
,
+
tl
.
dot
(
A_32
,
Ai_21
,
input_precision
=
"ieee"
),
input_precision
=
DOT_PRECISION
,
input_precision
=
"ieee"
,
)
)
Ai_42
=
-
tl
.
dot
(
b_Ai_32
=
-
tl
.
dot
(
Ai_44
,
tl
.
dot
(
b_Ai_33
,
b_A_32
,
input_precision
=
DOT_PRECISION
),
tl
.
dot
(
A_42
,
Ai_22
,
input_precision
=
"ieee"
)
b_Ai_22
,
+
tl
.
dot
(
A_43
,
Ai_32
,
input_precision
=
"ieee"
),
input_precision
=
DOT_PRECISION
,
input_precision
=
"ieee"
,
)
)
Ai_41
=
-
tl
.
dot
(
b_Ai_43
=
-
tl
.
dot
(
Ai_44
,
tl
.
dot
(
b_Ai_44
,
b_A_43
,
input_precision
=
DOT_PRECISION
),
tl
.
dot
(
A_41
,
Ai_11
,
input_precision
=
"ieee"
)
b_Ai_33
,
+
tl
.
dot
(
A_42
,
Ai_21
,
input_precision
=
"ieee"
)
input_precision
=
DOT_PRECISION
,
+
tl
.
dot
(
A_43
,
Ai_31
,
input_precision
=
"ieee"
),
input_precision
=
"ieee"
,
)
)
p_Ai_11
=
tl
.
make_block_ptr
(
b_Ai_31
=
-
tl
.
dot
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
0
),
(
16
,
16
),
(
1
,
0
)
b_Ai_33
,
)
tl
.
dot
(
b_A_31
,
b_Ai_11
,
input_precision
=
DOT_PRECISION
)
p_Ai_22
=
tl
.
make_block_ptr
(
+
tl
.
dot
(
b_A_32
,
b_Ai_21
,
input_precision
=
DOT_PRECISION
),
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
16
),
(
16
,
16
),
(
1
,
0
)
input_precision
=
DOT_PRECISION
,
)
)
p_Ai_33
=
tl
.
make_block_ptr
(
b_Ai_42
=
-
tl
.
dot
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
32
),
(
16
,
16
),
(
1
,
0
)
b_Ai_44
,
)
tl
.
dot
(
b_A_42
,
b_Ai_22
,
input_precision
=
DOT_PRECISION
)
p_Ai_44
=
tl
.
make_block_ptr
(
+
tl
.
dot
(
b_A_43
,
b_Ai_32
,
input_precision
=
DOT_PRECISION
),
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
48
),
(
16
,
16
),
(
1
,
0
)
input_precision
=
DOT_PRECISION
,
)
)
p_Ai_21
=
tl
.
make_block_ptr
(
b_Ai_41
=
-
tl
.
dot
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
b_Ai_44
,
)
tl
.
dot
(
b_A_41
,
b_Ai_11
,
input_precision
=
DOT_PRECISION
)
p_Ai_31
=
tl
.
make_block_ptr
(
+
tl
.
dot
(
b_A_42
,
b_Ai_21
,
input_precision
=
DOT_PRECISION
)
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
)
+
tl
.
dot
(
b_A_43
,
b_Ai_31
,
input_precision
=
DOT_PRECISION
),
)
input_precision
=
DOT_PRECISION
,
p_Ai_32
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
16
),
(
16
,
16
),
(
1
,
0
)
)
p_Ai_41
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
)
)
p_Ai_42
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
16
),
(
16
,
16
),
(
1
,
0
)
)
p_Ai_43
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
32
),
(
16
,
16
),
(
1
,
0
)
)
tl
.
store
(
p_Ai_11
,
Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_22
,
Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_33
,
Ai_33
.
to
(
p_Ai_33
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_44
,
Ai_44
.
to
(
p_Ai_44
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_21
,
Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_31
,
Ai_31
.
to
(
p_Ai_31
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_32
,
Ai_32
.
to
(
p_Ai_32
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_41
,
Ai_41
.
to
(
p_Ai_41
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_42
,
Ai_42
.
to
(
p_Ai_42
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_43
,
Ai_43
.
to
(
p_Ai_43
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
)
fill_zeros
=
tl
.
zeros
((
16
,
16
),
dtype
=
tl
.
float32
)
if
not
USE_TMA
:
p_Ai_12
=
tl
.
make_block_ptr
(
p_Ai_11
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
16
),
(
16
,
16
),
(
1
,
0
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
16
,
16
),
(
1
,
0
)
)
)
p_Ai_13
=
tl
.
make_block_ptr
(
p_Ai_22
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
32
),
(
16
,
16
),
(
1
,
0
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
16
),
(
16
,
16
),
(
1
,
0
)
)
)
p_Ai_14
=
tl
.
make_block_ptr
(
p_Ai_33
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
48
),
(
16
,
16
),
(
1
,
0
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
32
,
32
),
(
16
,
16
),
(
1
,
0
)
)
)
p_Ai_23
=
tl
.
make_block_ptr
(
p_Ai_44
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
32
),
(
16
,
16
),
(
1
,
0
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
48
),
(
16
,
16
),
(
1
,
0
)
)
)
p_Ai_24
=
tl
.
make_block_ptr
(
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
48
),
(
16
,
16
),
(
1
,
0
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
16
,
0
),
(
16
,
16
),
(
1
,
0
)
)
)
p_Ai_34
=
tl
.
make_block_ptr
(
p_Ai_31
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
48
),
(
16
,
16
),
(
1
,
0
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
32
,
0
),
(
16
,
16
),
(
1
,
0
)
)
)
tl
.
store
(
p_Ai_32
=
tl
.
make_block_ptr
(
p_Ai_12
,
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
32
,
16
),
(
16
,
16
),
(
1
,
0
)
fill_zeros
.
to
(
p_Ai_12
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
)
boundary_check
=
(
0
,
1
),
p_Ai_41
=
tl
.
make_block_ptr
(
)
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
0
),
(
16
,
16
),
(
1
,
0
)
tl
.
store
(
)
p_Ai_13
,
p_Ai_42
=
tl
.
make_block_ptr
(
fill_zeros
.
to
(
p_Ai_13
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
16
),
(
16
,
16
),
(
1
,
0
)
boundary_check
=
(
0
,
1
),
)
)
p_Ai_43
=
tl
.
make_block_ptr
(
tl
.
store
(
Ai
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
48
,
32
),
(
16
,
16
),
(
1
,
0
)
p_Ai_14
,
)
fill_zeros
.
to
(
p_Ai_14
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_11
,
)
b_Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_23
,
)
fill_zeros
.
to
(
p_Ai_23
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_22
,
)
b_Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_24
,
)
fill_zeros
.
to
(
p_Ai_24
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_33
,
)
b_Ai_33
.
to
(
p_Ai_33
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_34
,
)
fill_zeros
.
to
(
p_Ai_34
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
tl
.
store
(
boundary_check
=
(
0
,
1
),
p_Ai_44
,
)
b_Ai_44
.
to
(
p_Ai_44
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_21
,
b_Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_31
,
b_Ai_31
.
to
(
p_Ai_31
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_32
,
b_Ai_32
.
to
(
p_Ai_32
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_41
,
b_Ai_41
.
to
(
p_Ai_41
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_42
,
b_Ai_42
.
to
(
p_Ai_42
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
tl
.
store
(
p_Ai_43
,
b_Ai_43
.
to
(
p_Ai_43
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
),
)
else
:
desc_o
.
store
(
[
i_t
*
BT
+
0
,
0
],
b_Ai_11
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
16
,
16
],
b_Ai_22
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
32
,
32
],
b_Ai_33
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
48
,
48
],
b_Ai_44
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
16
,
0
],
b_Ai_21
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
32
,
0
],
b_Ai_31
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
32
,
16
],
b_Ai_32
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
48
,
0
],
b_Ai_41
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
48
,
16
],
b_Ai_42
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
desc_o
.
store
(
[
i_t
*
BT
+
48
,
32
],
b_Ai_43
.
to
(
desc_o
.
dtype
,
fp_downcast_rounding
=
"rtne"
)
)
@
input_guard
@
input_guard
...
@@ -410,62 +510,47 @@ def solve_tril(
...
@@ -410,62 +510,47 @@ def solve_tril(
output_dtype
:
torch
.
dtype
=
torch
.
float
,
output_dtype
:
torch
.
dtype
=
torch
.
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Compute the inverse of the
lower triangular
matrix
Compute the inverse of the matrix
I + A
A should be strictly lower triangular, i.e., A.triu() == 0.
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
Args:
A (torch.Tensor):
A (torch.Tensor):
[B, T, H,
K]
[B, T, H,
BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
The cumulative sequence lengths of the input tensor. Default: `None`.
Default: None.
output_dtype (torch.dtype):
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Returns:
Returns:
(I + A)^-1 with the same shape as A
(I + A)^-1 with the same shape as A
"""
"""
assert
A
.
shape
[
-
1
]
in
[
16
,
32
,
64
]
assert
A
.
shape
[
-
1
]
in
[
16
,
32
,
64
]
output_dtype
=
A
.
dtype
if
output_dtype
is
None
else
output_dtype
B
,
T
,
H
,
BT
=
A
.
shape
B
,
T
,
H
,
BT
=
A
.
shape
Ad
=
torch
.
empty
(
B
,
T
,
H
,
16
,
device
=
A
.
device
,
dtype
=
torch
.
float
if
BT
!=
16
else
output_dtype
)
chunk_indices
=
(
prepare_chunk_indices
(
cu_seqlens
,
16
)
if
cu_seqlens
is
not
None
else
None
)
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
16
)
solve_tril_16x16_kernel
[
NT
,
B
*
H
](
A
=
A
,
Ad
=
Ad
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
BT
=
BT
,
)
if
BT
==
16
:
return
Ad
Ai
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
A
.
device
,
dtype
=
output_dtype
)
merge_fn
=
(
merge_16x16_to_32x32_inverse_kernel
if
BT
==
32
else
merge_16x16_to_64x64_inverse_kernel
)
chunk_indices
=
(
chunk_indices
=
(
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
)
)
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
BT
)
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
BT
)
Ai
=
torch
.
zeros_like
(
A
,
dtype
=
output_dtype
)
if
BT
==
16
:
merge_fn
=
solve_tril_16x16_kernel
elif
BT
==
32
:
merge_fn
=
merge_16x16_to_32x32_inverse_kernel
elif
BT
==
64
:
merge_fn
=
merge_16x16_to_64x64_inverse_kernel
merge_fn
[
NT
,
B
*
H
](
merge_fn
[
NT
,
B
*
H
](
A
=
A
,
A
=
A
,
Ad
=
Ad
,
Ai
=
Ai
,
Ai
=
Ai
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
chunk_indices
=
chunk_indices
,
T
=
T
,
T
=
T
,
H
=
H
,
H
=
H
,
BT
=
BT
,
BT
=
BT
,
USE_TMA
=
is_tma_supported
,
DOT_PRECISION
=
FLA_TRIL_PRECISION
,
)
)
return
Ai
return
Ai
vllm/model_executor/layers/fla/ops/utils.py
View file @
9fce7bee
...
@@ -150,6 +150,11 @@ is_nvidia_hopper = is_nvidia and (
...
@@ -150,6 +150,11 @@ is_nvidia_hopper = is_nvidia and (
or
torch
.
cuda
.
get_device_capability
()[
0
]
>=
9
or
torch
.
cuda
.
get_device_capability
()[
0
]
>=
9
)
)
use_cuda_graph
=
is_nvidia
and
os
.
environ
.
get
(
"FLA_USE_CUDA_GRAPH"
,
"0"
)
==
"1"
use_cuda_graph
=
is_nvidia
and
os
.
environ
.
get
(
"FLA_USE_CUDA_GRAPH"
,
"0"
)
==
"1"
is_gather_supported
=
hasattr
(
triton
.
language
,
"gather"
)
is_tma_supported
=
(
is_nvidia
and
torch
.
cuda
.
get_device_capability
(
0
)[
0
]
>=
9
)
and
(
hasattr
(
triton
.
language
,
"_experimental_make_tensor_descriptor"
)
or
hasattr
(
triton
.
language
,
"make_tensor_descriptor"
)
)
def
get_all_max_shared_mem
():
def
get_all_max_shared_mem
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment