Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c2510944
Commit
c2510944
authored
Oct 27, 2023
by
danyao12
Browse files
mqa/gqa inference
parent
5ff2d646
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
157 additions
and
108 deletions
+157
-108
example/52_flash_atten_bias/run_batched_multihead_attention_bias_infer.inc
...atten_bias/run_batched_multihead_attention_bias_infer.inc
+60
-49
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_infer.inc
...atten_bias/run_grouped_multihead_attention_bias_infer.inc
+66
-51
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
+11
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
...gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
+20
-4
No files found.
example/52_flash_atten_bias/run_batched_multihead_attention_bias_infer.inc
View file @
c2510944
...
...
@@ -14,11 +14,12 @@ int run(int argc, char* argv[])
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1
=
13
;
// Output shape C[G0, M, G1Q, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1q_m_o = reshape(C_g_m_o, [g0, g1q, m, o])
// C_g0_m_g1q_o = permute(C_g0_g1q_m_o, [0, 2, 1, 3])
ck
::
index_t
G0
=
7
;
ck
::
index_t
G1Q
=
12
;
// h_q
ck
::
index_t
G1KV
=
12
;
// h_kv
float
alpha
=
1
;
...
...
@@ -35,64 +36,65 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
1
3
)
else
if
(
argc
==
1
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1
=
std
::
stoi
(
argv
[
9
]);
M
=
std
::
stoi
(
argv
[
4
]);
N
=
std
::
stoi
(
argv
[
5
]);
K
=
std
::
stoi
(
argv
[
6
]);
O
=
std
::
stoi
(
argv
[
7
]);
G0
=
std
::
stoi
(
argv
[
8
]);
G1Q
=
std
::
stoi
(
argv
[
9
]);
G1KV
=
std
::
stoi
(
argv
[
10
]);
alpha
=
std
::
stof
(
argv
[
1
0
]);
alpha
=
std
::
stof
(
argv
[
1
1
]);
input_permute
=
std
::
stoi
(
argv
[
1
1
]);
output_permute
=
std
::
stoi
(
argv
[
1
2
]);
input_permute
=
std
::
stoi
(
argv
[
1
2
]);
output_permute
=
std
::
stoi
(
argv
[
1
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 1
1
: M, N, K, O, G0, G1
\n
"
);
printf
(
"arg1
0
: scale (alpha)
\n
"
);
printf
(
"arg1
1
to 1
2
: input / output permute
\n
"
);
printf
(
"arg4 to 1
0
: M, N, K, O, G0, G1
Q, G1KV
\n
"
);
printf
(
"arg1
1
: scale (alpha)
\n
"
);
printf
(
"arg1
2
to 1
3
: input / output permute
\n
"
);
exit
(
0
);
}
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// A layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
K
,
K
,
G1
KV
*
K
,
1
}
// B0 layout [G0, N, G1
KV
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1
KV
, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
KV
*
O
,
O
,
1
,
G1
KV
*
O
}
// B1 layout [G0, N, G1
KV
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1
KV
, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// C layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// D0 layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// D0 layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// D0 layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// D0 layout [G0, G1
Q
, M, N]
Tensor
<
ADataType
>
a_gs_ms_ks
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
);
Tensor
<
B0DataType
>
b0_gs_ns_ks
(
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
);
...
...
@@ -188,14 +190,15 @@ int run(int argc, char* argv[])
return
0
;
}
ck
::
index_t
BatchCount
=
G0
*
G1
;
ck
::
index_t
BatchCount
=
G0
*
G1
Q
;
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
+
size_t
(
M
)
*
N
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
Acc0BiasDataType
)
*
M
*
N
)
*
BatchCount
;
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
+
size_t
(
M
)
*
N
)
*
BatchCount
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
Acc0BiasDataType
)
*
M
*
N
)
*
BatchCount
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
...
...
@@ -211,23 +214,31 @@ int run(int argc, char* argv[])
Tensor
<
ADataType
>
a_g_m_k
({
BatchCount
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
BatchCount
,
K
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
BatchCount
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
AccDataType
>
acc0_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after gemm0
Tensor
<
Acc0BiasDataType
>
d0_g_m_n
({
BatchCount
,
M
,
N
});
Tensor
<
ADataType
>
a1_g_m_n
({
BatchCount
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
BatchCount
,
M
,
O
});
// scratch object after gemm1
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
(
G1Q
/
G1KV
);
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
b1
_gs_
o
s_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1
_g_
n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
d0
_gs_
m
s_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0
_g_
m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// gemm 0
...
...
@@ -267,10 +278,10 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
...
...
example/52_flash_atten_bias/run_grouped_multihead_attention_bias_infer.inc
View file @
c2510944
...
...
@@ -10,6 +10,8 @@ int run(int argc, char* argv[])
bool
input_permute
=
false
;
bool
output_permute
=
true
;
int
h_ratio
=
1
;
// G1Q / G1KV
if
(
argc
==
1
)
{
// use default case
...
...
@@ -20,21 +22,23 @@ int run(int argc, char* argv[])
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
if
(
argc
==
6
)
else
if
(
argc
==
7
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
input_permute
=
std
::
stoi
(
argv
[
4
]);
output_permute
=
std
::
stoi
(
argv
[
5
]);
h_ratio
=
std
::
stof
(
argv
[
4
]);
input_permute
=
std
::
stoi
(
argv
[
5
]);
output_permute
=
std
::
stoi
(
argv
[
6
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 5: input / output permute
\n
"
);
printf
(
"arg4: h_ratio
\n
"
);
printf
(
"arg5 to 6: input / output permute
\n
"
);
exit
(
0
);
}
...
...
@@ -49,7 +53,7 @@ int run(int argc, char* argv[])
std
::
vector
<
const
void
*>
p_d0
;
std
::
vector
<
const
void
*>
p_b1
;
std
::
vector
<
void
*>
p_c
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1_m_n_k_o
;
std
::
vector
<
std
::
vector
<
int
>>
g0_g1
q
_m_n_k_o
;
std
::
vector
<
Tensor
<
ADataType
>>
a_tensors
;
std
::
vector
<
Tensor
<
B0DataType
>>
b0_tensors
;
...
...
@@ -69,44 +73,47 @@ int run(int argc, char* argv[])
std
::
cout
<<
"group count "
<<
group_count
<<
". printing first 4 groups
\n
"
;
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
128
*
(
rand
()
%
8
+
1
);
int
N
=
128
*
(
rand
()
%
8
+
1
);
int
K
=
40
;
int
O
=
40
*
(
rand
()
%
2
+
1
);
int
G0
=
rand
()
%
3
+
1
;
int
G1
=
rand
()
%
5
+
1
;
int
M
=
128
*
(
rand
()
%
8
+
1
);
int
N
=
128
*
(
rand
()
%
8
+
1
);
int
K
=
40
;
int
O
=
40
*
(
rand
()
%
2
+
1
);
int
G0
=
rand
()
%
3
+
1
;
int
G1KV
=
rand
()
%
5
+
1
;
int
G1Q
=
G1KV
*
h_ratio
;
g0_g1_m_n_k_o
.
push_back
({
G0
,
G1
,
M
,
N
,
K
,
O
});
g0_g1
q
_m_n_k_o
.
push_back
({
G0
,
G1
Q
,
M
,
N
,
K
,
O
});
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
Q
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// A layout [G0, M, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1, M, K]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
K
,
K
,
G1
Q
*
K
,
1
}
// A layout [G0, M, G1
Q
, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
K
,
M
*
K
,
K
,
1
};
// A layout [G0, G1
Q
, M, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
KV
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
}
// B0 layout [G0, N, G1, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1, N, K]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
K
,
K
,
G1KV
*
K
,
1
}
// B0 layout [G0, N, G1KV, K]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
K
,
N
*
K
,
K
,
1
};
// B0 layout [G0, G1KV, N, K]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
KV
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
}
// B1 layout [G0, N, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1, N, O]
?
std
::
vector
<
ck
::
index_t
>
{
N
*
G1KV
*
O
,
O
,
1
,
G1KV
*
O
}
// B1 layout [G0, N, G1KV, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1KV
*
N
*
O
,
N
*
O
,
1
,
O
};
// B1 layout [G0, G1KV, N, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
Q
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
=
output_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
}
// C layout [G0, M, G1, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1, M, O]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
O
,
O
,
G1
Q
*
O
,
1
}
// C layout [G0, M, G1
Q
, O]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
O
,
M
*
O
,
O
,
1
};
// C layout [G0, G1
Q
, M, O]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
Q
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
=
input_permute
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
}
// d0 layout [G0, M, G1, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1, M, N]
?
std
::
vector
<
ck
::
index_t
>
{
M
*
G1
Q
*
N
,
N
,
G1
Q
*
N
,
1
}
// d0 layout [G0, M, G1
Q
, N]
:
std
::
vector
<
ck
::
index_t
>
{
G1
Q
*
M
*
N
,
M
*
N
,
N
,
1
};
// d0 layout [G0, G1
Q
, M, N]
problem_descs
.
push_back
({
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
...
...
@@ -128,7 +135,7 @@ int run(int argc, char* argv[])
Tensor
<
B1DataType
>
b1_gs_os_ns
(
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
);
Tensor
<
CDataType
>
c_gs_ms_os_device_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
int
Batch
=
G0
*
G1
;
int
Batch
=
G0
*
G1
Q
;
flop
+=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
+
size_t
(
M
)
*
N
)
*
Batch
;
num_byte
+=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
...
...
@@ -248,12 +255,12 @@ int run(int argc, char* argv[])
{
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
const
int
&
G0
=
g0_g1_m_n_k_o
[
i
][
0
];
const
int
&
G1
=
g0_g1_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1_m_n_k_o
[
i
][
5
];
const
int
&
G0
=
g0_g1
q
_m_n_k_o
[
i
][
0
];
const
int
&
G1
Q
=
g0_g1
q
_m_n_k_o
[
i
][
1
];
const
int
&
M
=
g0_g1
q
_m_n_k_o
[
i
][
2
];
const
int
&
N
=
g0_g1
q
_m_n_k_o
[
i
][
3
];
const
int
&
K
=
g0_g1
q
_m_n_k_o
[
i
][
4
];
const
int
&
O
=
g0_g1
q
_m_n_k_o
[
i
][
5
];
const
auto
&
c_gs_ms_os_lengths
=
problem_descs
[
i
]
.
c_gs_ms_os_lengths
;
const
auto
&
c_gs_ms_os_strides
=
problem_descs
[
i
]
.
c_gs_ms_os_strides
;
...
...
@@ -267,27 +274,35 @@ int run(int argc, char* argv[])
c_gs_ms_os_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
,
K
,
N
});
Tensor
<
Acc0BiasDataType
>
d0_g_m_n
({
G0
*
G1
,
M
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
,
M
,
O
});
// scratch object after gemm1
Tensor
<
ADataType
>
a_g_m_k
({
G0
*
G1
Q
,
M
,
K
});
Tensor
<
B0DataType
>
b0_g_k_n
({
G0
*
G1
Q
,
K
,
N
});
Tensor
<
Acc0BiasDataType
>
d0_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
Tensor
<
B1DataType
>
b1_g_n_o
({
G0
*
G1
Q
,
N
,
O
});
Tensor
<
AccDataType
>
acc0_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after gemm0
Tensor
<
ADataType
>
a1_g_m_n
({
G0
*
G1
Q
,
M
,
N
});
// scratch object after softmax
Tensor
<
CDataType
>
c_g_m_o_host_result
({
G0
*
G1
Q
,
M
,
O
});
// scratch object after gemm1
Tensor
<
CDataType
>
c_gs_ms_os_host_result
(
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
);
// permute
a_gs_ms_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
a_g_m_k
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
a_g_m_k
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
b0_gs_ns_ks
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b0_g_k_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
b0_g_k_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
b0_gs_ns_ks
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
d0_gs_ms_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0_g_m_n
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
b1_g_n_o
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
]
/
G1Q
;
const
size_t
&
g1q
=
idx
[
0
]
%
G1Q
;
const
size_t
&
g1kv
=
g1q
/
h_ratio
;
self
(
idx
)
=
b1_gs_os_ns
(
g0
,
g1kv
,
idx
[
2
],
idx
[
1
]);
});
b1
_gs_
o
s_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
b1
_g_
n_o
(
idx
[
0
]
*
G1
+
idx
[
1
],
idx
[
3
],
idx
[
2
])
=
self
(
idx
);
d0
_gs_
m
s_ns
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
d0
_g_
m_n
(
idx
[
0
]
*
G1
Q
+
idx
[
1
],
idx
[
2
],
idx
[
3
])
=
self
(
idx
);
});
// gemm 0
...
...
@@ -331,10 +346,10 @@ int run(int argc, char* argv[])
// permute
c_gs_ms_os_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
=
idx
[
1
];
const
size_t
&
g0
=
idx
[
0
];
const
size_t
&
g1
q
=
idx
[
1
];
const
size_t
g
=
g0
*
G1
+
g1
;
const
size_t
g
=
g0
*
G1
Q
+
g1
q
;
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_infer_xdl_cshuffle.hpp
View file @
c2510944
...
...
@@ -64,6 +64,7 @@ __global__ void
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
h_ratio
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
{
...
...
@@ -73,13 +74,14 @@ __global__ void
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetCBasePtr
(
g_idx
)));
...
...
@@ -130,6 +132,7 @@ __global__ void
ignore
=
d0_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
h_ratio
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -512,7 +515,8 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
b1_gs_gemm1ns_gemm1ks_strides
[
NumDimG
+
NumDimO
+
NumDimN
-
1
]},
c_mz_gemm1nz_strides_
{
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
-
1
],
c_gs_ms_gemm1ns_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
batch_count_
{
c1_grid_desc_g_m_n_
.
GetLength
(
I0
)}
batch_count_
{
c1_grid_desc_g_m_n_
.
GetLength
(
I0
)},
h_ratio_
{
c1_grid_desc_g_m_n_
.
GetLength
(
I0
)
/
b_grid_desc_g_n_k_
.
GetLength
(
I0
)}
{
// TODO ANT: implement bias addition
ignore
=
p_acc1_bias
;
...
...
@@ -613,6 +617,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
std
::
vector
<
ck
::
index_t
>
d0s_nl_ns_lengths_strides_
;
index_t
batch_count_
;
index_t
h_ratio_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
// raw data
...
...
@@ -683,6 +688,7 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
arg
.
d0_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
h_ratio_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
);
};
...
...
@@ -730,12 +736,13 @@ struct DeviceBatchedMultiheadAttentionInfer_Xdl_CShuffle
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
arg
.
c1_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
a_m
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_g
==
arg
.
batch_count_
&&
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
))
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_infer_xdl_cshuffle.hpp
View file @
c2510944
...
...
@@ -39,6 +39,7 @@ __global__ void
kernel_grouped_multiple_head_flash_attention_infer
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
group_kernel_args
,
const
index_t
group_count
,
const
index_t
h_ratio
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
...
...
@@ -76,13 +77,14 @@ __global__ void
const
index_t
num_blocks_per_batch
=
arg_ptr
[
group_id
].
num_blocks_per_batch_
;
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
(
block_id
-
arg_ptr
[
group_id
].
block_start_
)
/
num_blocks_per_batch
);
const
index_t
gkv_idx
=
__builtin_amdgcn_readfirstlane
(
g_idx
/
h_ratio
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetABasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetBBasePtr
(
g
kv
_idx
)));
const
long_index_t
b1_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g_idx
)));
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetB1BasePtr
(
g
kv
_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
arg_ptr
[
group_id
].
compute_base_ptr_of_batch_
.
GetCBasePtr
(
g_idx
)));
...
...
@@ -118,6 +120,7 @@ __global__ void
#else
ignore
=
group_kernel_args
;
ignore
=
group_count
;
ignore
=
h_ratio
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
...
...
@@ -495,6 +498,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
// for gridwise gemm check
C1GridDesc_M_N
c1_grid_desc_m_n_
;
BGridDesc_G_N_K
b_grid_desc_g_n_k_
;
C1GridDesc_G_M_N
c1_grid_desc_g_m_n_
;
// raw data
std
::
vector
<
ck
::
index_t
>
d0_n_length_stride_
;
...
...
@@ -536,6 +541,9 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
grid_size_
=
0
;
h_ratio_
=
problem_desc_vec
[
0
].
a_gs_ms_ks_lengths
[
NumDimG
-
1
]
/
problem_desc_vec
[
0
].
b0_gs_ns_ks_lengths
[
NumDimG
-
1
];
for
(
std
::
size_t
i
=
0
;
i
<
group_count_
;
i
++
)
{
const
auto
p_a_grid
=
static_cast
<
const
ADataType
*>
(
p_a_vec
[
i
]);
...
...
@@ -648,6 +656,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
{
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
-
1
],
problem_desc
.
c_gs_ms_os_strides
[
NumDimG
+
NumDimM
+
NumDimO
-
1
]},
c_grid_desc_m_n
,
b_grid_desc_g_n_k
,
c1_grid_desc_g_m_n
,
d0_n_length_stride
});
}
}
...
...
@@ -663,6 +673,8 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
h_ratio_
;
};
// Invoker
...
...
@@ -739,6 +751,7 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
arg
.
group_count_
,
arg
.
h_ratio_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
...
...
@@ -797,11 +810,14 @@ struct DeviceGroupedMultiheadAttentionInfer_Xdl_CShuffle
const
auto
&
device_arg
=
arg
.
group_device_args_
[
i
];
// Check if C permute dimension matches GEMM + GEMM shape
const
index_t
c_g
=
device_arg
.
c1_grid_desc_g_m_n_
.
GetLength
(
I0
);
// unpadded
const
index_t
b_g
=
device_arg
.
b_grid_desc_g_n_k_
.
GetLength
(
I0
);
const
index_t
c_m
=
device_arg
.
c1_grid_desc_m_n_
.
GetLength
(
I0
);
const
index_t
c_gemm1n
=
device_arg
.
c1_grid_desc_m_n_
.
GetLength
(
I1
);
const
index_t
a_m
=
kernel_arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
);
const
index_t
b1_gemm1n
=
kernel_arg
.
b1_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
);
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
))
if
(
!
(
c_m
==
a_m
&&
c_gemm1n
==
b1_gemm1n
&&
c_g
%
b_g
==
0
&&
c_g
/
b_g
==
arg
.
h_ratio_
))
{
return
false
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment