gaoqiong / composable_kernel

Commit 6a2d7c9f, authored Sep 26, 2023 by danyao12
Parent: c459f488

    fwd mqa/gqa

Showing 10 changed files with 101 additions and 49 deletions.
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp         +1  -1
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc        +28 -18
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc        +32 -17
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp           +1  -0
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp  +1  -0
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp  +1  -0
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp        +12 -6
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp  +1  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp  +1  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp        +23 -7
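Editor's note: this commit teaches the forward attention examples and device operators about MQA/GQA. The flattened attention batch keeps G1 query heads (h_q) per batch, but K and V are now stored with only G2 key/value heads (h_kv), and each run of h_ratio = G1 / G2 consecutive query heads reads the same K/V head. A minimal standalone sketch of that index mapping (illustrative values only, not part of the commit):

    #include <cassert>
    #include <cstdio>

    // Sketch of the MQA/GQA head mapping used throughout this commit:
    // query head q_head reads K/V from head q_head / h_ratio.
    int kv_head_of(int q_head, int h_ratio) { return q_head / h_ratio; }

    int main()
    {
        const int h_q  = 12; // query heads (G1) -- hypothetical values
        const int h_kv = 4;  // key/value heads (G2)
        assert(h_q % h_kv == 0);
        const int h_ratio = h_q / h_kv; // 3 query heads share each K/V head

        for(int q = 0; q < h_q; ++q)
            std::printf("q head %2d -> kv head %d\n", q, kv_head_of(q, h_ratio));
        // q heads 0..2 -> kv 0, 3..5 -> kv 1, 6..8 -> kv 2, 9..11 -> kv 3
        return 0;
    }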
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp

@@ -75,7 +75,7 @@ static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecia
 static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto TensorSpecC  = ck::tensor_operation::device::TensorSpecialization::Default;
-static constexpr bool Deterministic = true;
+static constexpr bool Deterministic = false;
 #if(DIM <= 32)
 using DeviceGemmInstance =

example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
@@ -18,7 +18,8 @@ int run(int argc, char* argv[])
     // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
     // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t G0 = 7;
-    ck::index_t G1 = 13;
+    ck::index_t G1 = 12; // h_q
+    ck::index_t G2 = 12; // h_kv
     bool input_permute  = false;
     bool output_permute = true;
@@ -37,7 +38,7 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 13)
+    else if(argc == 14)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
@@ -49,20 +50,21 @@ int run(int argc, char* argv[])
         O  = std::stoi(argv[7]);
         G0 = std::stoi(argv[8]);
         G1 = std::stoi(argv[9]);
+        G2 = std::stoi(argv[10]);

-        p_drop = std::stof(argv[10]);
+        p_drop = std::stof(argv[11]);

-        input_permute  = std::stoi(argv[11]);
-        output_permute = std::stoi(argv[12]);
+        input_permute  = std::stoi(argv[12]);
+        output_permute = std::stoi(argv[13]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 11: M, N, K, O, G0, G1\n");
-        printf("arg10: scale (alpha)\n");
-        printf("arg11 to 12: input / output permute\n");
+        printf("arg4 to 10: M, N, K, O, G0, G1, G2\n");
+        printf("arg11: p_drop\n");
+        printf("arg12 to 13: input / output permute\n");
         exit(0);
     }
@@ -77,17 +79,17 @@ int run(int argc, char* argv[])
              ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
              : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]

-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]

     std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> c_gs_ms_os_strides =
@@ -286,11 +288,19 @@ int run(int argc, char* argv[])
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b0_g_k_n.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / (G1 / G2);
+            self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
         });
-        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b1_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / (G1 / G2);
+            self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
         });
         z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
             z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
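Editor's note: the reworked loops above no longer scatter b0/b1 into the flattened batch; they gather, decomposing each flattened batch index idx[0] into (g0, g1) and picking the shared K/V head g2 = g1 / (G1 / G2). A self-contained sketch of the same gather, with a hypothetical scalar per head standing in for the K/V tensors:

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Same decomposition as the ForEach loops above: flattened batch
        // index g = g0 * G1 + g1, and query head g1 maps to K/V head g2.
        const int G0 = 2, G1 = 6, G2 = 3; // h_ratio = G1 / G2 = 2 (assumed values)

        // One stored K/V value per (g0, g2) pair.
        std::vector<int> kv(G0 * G2);
        for(int i = 0; i < G0 * G2; ++i) kv[i] = 100 + i;

        // Broadcast into the (G0 * G1)-sized batch the kernel iterates over.
        std::vector<int> kv_broadcast(G0 * G1);
        for(int g = 0; g < G0 * G1; ++g)
        {
            const int g0 = g / G1;
            const int g1 = g % G1;
            const int g2 = g1 / (G1 / G2); // consecutive q heads share one K/V head
            kv_broadcast[g] = kv[g0 * G2 + g2];
        }

        for(int g = 0; g < G0 * G1; ++g)
            std::printf("batch %2d uses kv slot %d\n", g, kv_broadcast[g]);
        return 0;
    }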

example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
@@ -11,6 +11,7 @@ int run(int argc, char* argv[])
     bool output_permute = true;

     float p_drop = 0.2;
+    int h_ratio  = 1; // G1 / G2
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
@@ -24,22 +25,25 @@ int run(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 7)
+    else if(argc == 8)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
         time_kernel     = std::stoi(argv[3]);

         p_drop = std::stoi(argv[4]);

-        input_permute  = std::stoi(argv[5]);
-        output_permute = std::stoi(argv[6]);
+        h_ratio = std::stof(argv[5]);
+
+        input_permute  = std::stoi(argv[6]);
+        output_permute = std::stoi(argv[7]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 5: input / output permute\n");
+        printf("arg4: p_drop\n");
+        printf("arg5: h_ratio\n");
+        printf("arg6 to 7: input / output permute\n");
         exit(0);
     }
@@ -88,7 +92,8 @@ int run(int argc, char* argv[])
         int K = DIM;
         int O = DIM;
         int G0 = rand() % 3 + 1;
-        int G1 = rand() % 5 + 1;
+        int G2 = rand() % 5 + 1;
+        int G1 = G2 * h_ratio;

         g0_g1_m_n_k_o.push_back({G0, G1, M, N, K, O});
@@ -98,17 +103,17 @@ int run(int argc, char* argv[])
              ? std::vector<ck::index_t>{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K]
              : std::vector<ck::index_t>{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K]

-    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+    std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G2, N, K};
     std::vector<ck::index_t> b0_gs_ns_ks_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K]
-            : std::vector<ck::index_t>{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K]
+            ? std::vector<ck::index_t>{N * G2 * K, K, G2 * K, 1} // B0 layout [G0, N, G2, K]
+            : std::vector<ck::index_t>{G2 * N * K, N * K, K, 1}; // B0 layout [G0, G2, N, K]

-    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+    std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G2, O, N};
     std::vector<ck::index_t> b1_gs_os_ns_strides =
         input_permute
-            ? std::vector<ck::index_t>{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O]
-            : std::vector<ck::index_t>{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O]
+            ? std::vector<ck::index_t>{N * G2 * O, O, 1, G2 * O} // B1 layout [G0, N, G2, O]
+            : std::vector<ck::index_t>{G2 * N * O, N * O, 1, O}; // B1 layout [G0, G2, N, O]

     std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
     std::vector<ck::index_t> c_gs_ms_os_strides =
@@ -253,7 +258,8 @@ int run(int argc, char* argv[])
             acc0_element_op,
             b1_element_op,
             c_element_op,
-            p_drop, // dropout ratio
+            p_drop, // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should be
                              // at least the number of elements on a thread
@@ -296,7 +302,8 @@ int run(int argc, char* argv[])
             acc0_element_op,
             b1_element_op,
             c_element_op,
-            p_drop, // dropout ratio
+            p_drop, // dropout ratio
+            h_ratio,
             {seed, offset}); // dropout random seed and offset, offset should be
                              // at least the number of elements on a thread

         // specify workspace for problem_desc
@@ -350,11 +357,19 @@ int run(int argc, char* argv[])
         a_gs_ms_ks.ForEach([&](auto& self, auto idx) {
             a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
         });
-        b0_gs_ns_ks.ForEach([&](auto& self, auto idx) {
-            b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b0_g_k_n.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx) = b0_gs_ns_ks(g0, g2, idx[2], idx[1]);
         });
-        b1_gs_os_ns.ForEach([&](auto& self, auto idx) {
-            b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
+        b1_g_n_o.ForEach([&](auto& self, auto idx) {
+            const size_t& g0 = idx[0] / G1;
+            const size_t& g1 = idx[0] % G1;
+            const size_t& g2 = g1 / h_ratio;
+            self(idx) = b1_gs_os_ns(g0, g2, idx[2], idx[1]);
         });
         z_gs_ms_ns_device_result.ForEach([&](auto& self, auto idx) {
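Editor's note: the grouped example divides by h_ratio directly, while the batched example divides by (G1 / G2). Since this example constructs G1 = G2 * h_ratio, the two expressions agree, as this small check (assumed values) illustrates:

    #include <cassert>

    int main()
    {
        // With G1 == G2 * h_ratio, g1 / (G1 / G2) == g1 / h_ratio for all g1.
        const int G2 = 3, h_ratio = 2, G1 = G2 * h_ratio;
        for(int g1 = 0; g1 < G1; ++g1)
            assert(g1 / (G1 / G2) == g1 / h_ratio);
        return 0;
    }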

include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
@@ -134,6 +134,7 @@ struct DeviceGroupedMultiheadAttentionForward : public BaseOperator
         B1ElementwiseOperation b1_element_op,
         CElementwiseOperation c_element_op,
         float p_dropout,
+        index_t h_ratio,
         std::tuple<unsigned long long, unsigned long long> seeds) = 0;

     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
@@ -254,6 +254,7 @@ __global__ void
     ignore = ygrad_grid_desc_o0_m_o1;
     ignore = block_2_ctile_map;
     ignore = batch_count;
+    ignore = h_ratio;
     ignore = nblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;

include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
@@ -255,6 +255,7 @@ __global__ void
     ignore = ygrad_grid_desc_m0_o_m1;
     ignore = block_2_ctile_map;
     ignore = batch_count;
+    ignore = h_ratio;
     ignore = nblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;

include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_xdl_cshuffle_v2.hpp
@@ -78,6 +78,7 @@ __global__ void
     const LSEGridDescriptor_M lse_grid_desc_m,
     const Block2CTileMap block_2_ctile_map,
     const index_t batch_count,
+    const index_t h_ratio,
     const index_t mblock,
     const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
     const C0MatrixMask c0_matrix_mask,
@@ -94,13 +95,14 @@ __global__ void
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(gkv_idx)));
     const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
+        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(gkv_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
     const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
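Editor's note: in the hunk above the per-batch base pointers split in two. A (Q), C (output), and Z stay indexed by the query batch g_idx, while B (K) and B1 (V) are indexed by gkv_idx = g_idx / h_ratio, so all blocks in one query-head group read the same K/V slab. A standalone sketch of the offset arithmetic, with hypothetical strides:

    #include <cstdio>

    // A and C advance with the query-head index g_idx; B (K) and B1 (V)
    // advance with gkv_idx = g_idx / h_ratio.
    int main()
    {
        const long a_batch_stride = 1024; // per query head (assumed)
        const long b_batch_stride = 2048; // per K/V head (assumed)
        const int  h_ratio        = 4;

        for(int g_idx = 0; g_idx < 8; ++g_idx)
        {
            const int  gkv_idx        = g_idx / h_ratio;
            const long a_batch_offset = g_idx * a_batch_stride;
            const long b_batch_offset = gkv_idx * b_batch_stride;
            std::printf("g_idx %d: A offset %ld, B offset %ld (kv head %d)\n",
                        g_idx, a_batch_offset, b_batch_offset, gkv_idx);
        }
        return 0;
    }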
@@ -211,6 +213,7 @@ __global__ void
     ignore = lse_grid_desc_m;
     ignore = block_2_ctile_map;
     ignore = batch_count;
+    ignore = h_ratio;
     ignore = mblock;
     ignore = compute_base_ptr_of_batch;
     ignore = c0_matrix_mask;
@@ -662,7 +665,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
               b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]},
           c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
                                 c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
-          batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}
+          batch_count_{c_grid_desc_g_m_n_.GetLength(I0)},
+          h_ratio_{c_grid_desc_g_m_n_.GetLength(I0) / b_grid_desc_g_n_k_.GetLength(I0)}
     {
         // TODO ANT: implement bias addition
         ignore = p_acc1_biases;
@@ -736,10 +740,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
             std::cout << "d0_grid_desc_g_m_n_: " << d0_grid_desc_g_m_n_.GetLength(I0) << ", "
                       << d0_grid_desc_g_m_n_.GetLength(I1) << ", "
                       << d0_grid_desc_g_m_n_.GetLength(I2) << '\n';
             std::cout << "d0_grid_desc_m_n_: " << d0_grid_desc_m_n_.GetLength(I0) << ", "
                       << d0_grid_desc_m_n_.GetLength(I1) << '\n';
             std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I1) << ", "
                       << b1_grid_desc_g_n_k_.GetLength(I2) << '\n';
@@ -802,6 +804,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         std::vector<index_t> c_mz_gemm1nz_strides_;

         index_t batch_count_;
+        index_t h_ratio_;

         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

         float p_dropout_;
@@ -900,6 +903,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
                 arg.lse_grid_desc_m_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
+                arg.h_ratio_,
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_),
                 arg.compute_base_ptr_of_batch_,
                 arg.c0_matrix_mask_,
@@ -1014,12 +1018,14 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         // Check if C permute dimension matches GEMM + GEMM shape
         const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+        const index_t b_g = arg.b_grid_desc_g_n_k_.GetLength(I0);
         const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
         const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
         const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
         const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

-        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
+        if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
+             b_g <= c_g))
         {
             return false;
         }
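Editor's note: the strengthened IsSupportedArgument check above requires the C/query batch count to be an integer multiple of the B/KV batch count, which is exactly what makes gkv_idx = g_idx / h_ratio well defined. A minimal sketch of the same predicate (values hypothetical):

    #include <cstdio>

    // Total C batch count (G0 * h_q) must be a multiple of the B batch
    // count (G0 * h_kv), and B cannot have more batches than C.
    bool is_supported(int c_g, int b_g)
    {
        return c_g % b_g == 0 && b_g <= c_g;
    }

    int main()
    {
        std::printf("%d\n", is_supported(24, 8)); // 1: h_ratio = 3
        std::printf("%d\n", is_supported(24, 9)); // 0: 24 is not a multiple of 9
        return 0;
    }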

include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
@@ -209,6 +209,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;

include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
@@ -208,6 +208,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;

include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
@@ -44,6 +44,7 @@ __global__ void
 kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v2(const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args,
                                                  const index_t group_count,
+                                                 const index_t h_ratio,
                                                  const AElementwiseOperation a_element_op,
                                                  const BElementwiseOperation b_element_op,
                                                  const AccElementwiseOperation acc_element_op,
@@ -88,13 +89,14 @@ __global__ void
     const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_;
     const index_t g_idx = __builtin_amdgcn_readfirstlane(
         (block_id - arg_ptr[group_id].block_start_) / (Deterministic ? 1 : num_blocks_per_batch));
+    const index_t gkv_idx = __builtin_amdgcn_readfirstlane(g_idx / h_ratio);

     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
-        arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(gkv_idx)));
     const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
-        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
+        arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(gkv_idx)));
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
     const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -194,6 +196,7 @@ __global__ void
 #else
     ignore = group_kernel_args;
     ignore = group_count;
+    ignore = h_ratio;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = acc_element_op;
@@ -415,7 +418,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
                              const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
     {
         return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
                                                   acc0_biases_gs_ms_ns_strides);
     }
@@ -424,7 +426,6 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
     MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
                                const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
     {
         return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
                                                     acc0_biases_gs_ms_ns_strides);
     }
@@ -655,6 +656,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         // for gridwise gemm check
         CGridDesc_M_N c_grid_desc_m_n_;
+        BGridDesc_G_N_K b_grid_desc_g_n_k_;
+        CGridDesc_G_M_N c_grid_desc_g_m_n_;

         // raw data
         std::vector<ck::index_t> d0_n_length_stride_;
@@ -679,12 +682,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
                  float p_dropout,
+                 index_t h_ratio,
                  std::tuple<unsigned long long, unsigned long long> seeds)
             : a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               acc_element_op_{acc_element_op},
               b1_element_op_{b1_element_op},
-              c_element_op_{c_element_op}
+              c_element_op_{c_element_op},
+              h_ratio_{h_ratio}
         {
             ignore = p_acc1_biases_vec; // TODO ANT: implement bias addition
@@ -855,6 +860,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                     {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1],
                      problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]},
                     c_grid_desc_m_n,
+                    b_grid_desc_g_n_k,
+                    c_grid_desc_g_m_n,
                     d0_n_length_stride});
             }
@@ -880,6 +887,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
         B1ElementwiseOperation b1_element_op_;
         CElementwiseOperation c_element_op_;
+        index_t h_ratio_;

         float p_dropout_;
         uint8_t p_dropout_in_uint8_t_;
         unsigned long long seed_;
@@ -969,6 +977,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                     0,
                     cast_pointer_to_constant_address_space(arg.p_workspace_),
                     arg.group_count_,
+                    arg.h_ratio_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.acc_element_op_,
@@ -1091,11 +1100,14 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
             const auto& device_arg = arg.group_device_args_[i];
             // Check if C permute dimension matches GEMM + GEMM shape
+            const index_t c_g = device_arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+            const index_t b_g = device_arg.b_grid_desc_g_n_k_.GetLength(I0);
             const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0);
             const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1);
             const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
             const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);

-            if(!(c_m == a_m && c_gemm1n == b1_gemm1n))
+            if(!(c_m == a_m && c_gemm1n == b1_gemm1n && c_g % b_g == 0 &&
+                 c_g / b_g == arg.h_ratio_))
             {
                 return false;
             }
@@ -1203,6 +1215,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
                  float p_dropout,
+                 index_t h_ratio,
                  std::tuple<unsigned long long, unsigned long long> seeds)
     {
         return Argument{p_a_vec,
@@ -1220,6 +1233,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                         b1_element_op,
                         c_element_op,
                         p_dropout,
+                        h_ratio,
                         seeds};
     }
@@ -1242,6 +1256,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                  B1ElementwiseOperation b1_element_op,
                  CElementwiseOperation c_element_op,
                  float p_dropout,
+                 index_t h_ratio,
                  std::tuple<unsigned long long, unsigned long long> seeds) override
     {
         return std::make_unique<Argument>(p_a_vec,
@@ -1259,6 +1274,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
                         b1_element_op,
                         c_element_op,
                         p_dropout,
+                        h_ratio,
                         seeds);
     }