gaoqiong / composable_kernel · Commits

Commit 4e79cc4b
Authored Jan 28, 2023 by ltqin

save z matrix

Parent: cb914a54
Showing 4 changed files with 37 additions and 23 deletions (+37, -23).
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp  (+1, -1)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+1, -1)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp  (+16, -15)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp  (+19, -6)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_train_xdl_bf16.cpp  (file mode 100755 → 100644)

@@ -32,7 +32,7 @@ template <ck::index_t... Is>
 using S           = ck::Sequence<Is...>;
 using BF16        = ck::bhalf_t;
 using F32         = float;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp

@@ -479,7 +479,7 @@ int run(int argc, char* argv[])
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
         QKVElementOp{},
         QKVElementOp{},
         Scale{alpha},
         QKVElementOp{},
         YElementOp{});
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16_dropout.cpp
@@ -24,12 +24,13 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 1
+#define USING_MASK 0

 #include <iostream>
 #include <numeric>
 #include <initializer_list>
 #include <cstdlib>
+#include <fstream>

 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
@@ -259,12 +260,12 @@ int run(int argc, char* argv[])
     bool input_permute  = false;
     bool output_permute = false;

     float p_drop     = 0.2;
     float p_dropout  = 1 - p_drop;
     float rp_dropout = 1.0 / p_dropout;

     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;

     float scale_rp_dropout = alpha * rp_dropout;
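The dropout bookkeeping above follows the usual inverted-dropout convention: the keep probability is p_dropout = 1 - p_drop, survivors are rescaled by rp_dropout = 1/p_dropout so the expected value is unchanged, and scale_rp_dropout folds the softmax scaling alpha and the dropout rescale into a single multiplier. A minimal standalone sketch of the arithmetic (the head-dim value behind alpha is illustrative, not from the example):

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float alpha  = 1.0f / std::sqrt(128.0f); // softmax scaling, e.g. 1/sqrt(head_dim)
        const float p_drop = 0.2f;                     // probability of zeroing an element

        const float p_dropout  = 1 - p_drop;           // probability of keeping an element
        const float rp_dropout = 1.0f / p_dropout;     // rescale survivors so E[x] is unchanged
        const float scale_rp_dropout = alpha * rp_dropout; // one fused multiplier for dQ

        // Keeping with probability p_dropout and multiplying by rp_dropout is
        // the identity in expectation.
        assert(std::fabs(p_dropout * rp_dropout - 1.0f) < 1e-6f);
        return 0;
    }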
@@ -333,7 +334,6 @@ int run(int argc, char* argv[])
             ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // Y layout [G0, M, G1, O]
             : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // Y layout [G0, G1, M, O]

     std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
     std::vector<ck::index_t> z_gs_ms_ns_strides =
         input_permute
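The two stride vectors above describe the same logical [G0, G1, M, O] tensor stored in two different memory orders, and z_gs_ms_ns_strides is chosen the same way for the new [G0, G1, M, N] z matrix. A minimal sketch of how such stride vectors linearize a logical index (sizes are illustrative):

    #include <cstdio>
    #include <vector>

    // Linear offset of logical element (g0, g1, m, o) under per-dimension strides.
    int offset(const std::vector<int>& s, int g0, int g1, int m, int o)
    {
        return g0 * s[0] + g1 * s[1] + m * s[2] + o * s[3];
    }

    int main()
    {
        const int G1 = 2, M = 4, O = 8;                      // illustrative sizes
        std::vector<int> permuted{M * G1 * O, O, G1 * O, 1}; // memory order [G0, M, G1, O]
        std::vector<int> packed{G1 * M * O, M * O, O, 1};    // memory order [G0, G1, M, O]

        // Same logical element, two different linear addresses.
        std::printf("permuted: %d, packed: %d\n",
                    offset(permuted, 0, 1, 2, 3),
                    offset(packed, 0, 1, 2, 3));
        return 0;
    }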
@@ -475,7 +475,7 @@ int run(int argc, char* argv[])
     ygrad_device_buf.ToDevice(ygrad_gs_ms_os.mData.data());
     kgrad_device_buf.SetZero();
     vgrad_device_buf.SetZero();
-    //z_device_buf.SetZero();
+    // z_device_buf.SetZero();

     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();
@@ -509,11 +509,11 @@ int run(int argc, char* argv[])
         {}, // std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
         QKVElementOp{},
         QKVElementOp{},
-        Scale{scale_rp_dropout}, //dQ *= scale_rp_dropout
+        Scale{scale_rp_dropout}, // dQ *= scale_rp_dropout
         QKVElementOp{},
         YElementOp{},
         p_drop,
         std::tuple<unsigned long long, unsigned long long>(seed, offset));

     if(!gemm.IsSupportedArgument(argument))
     {
@@ -543,13 +543,14 @@ int run(int argc, char* argv[])
     std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
               << gemm.GetTypeString() << std::endl;

+    // copy z matrix data from device
+    std::ofstream file("./z_matrix_txt");
+    z_device_buf.FromDevice(z_g_m_n.mData.data());
+    file << z_g_m_n << std::endl;
+    // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
+
     bool pass = true;

     if(do_verification)
     {
-        // copy z matrix data from device
-        z_device_buf.FromDevice(z_g_m_n.mData.data());
-        // std::cout << "z_g_m_n ref:\n" << z_g_m_n;
         kgrad_device_buf.SetZero(); // reset global accum buffer and rerun
         vgrad_device_buf.SetZero();
         invoker.Run(argument, StreamConfig{nullptr, false});
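This hunk is the heart of the commit: after the timed run, the z matrix (the dropout mask the kernel produces) is copied back to the host and written to ./z_matrix_txt unconditionally, instead of only inside the verification branch. A minimal sketch of the same host-side pattern, with a plain std::vector standing in for ck's DeviceMem/Tensor types (copy_from_device is a hypothetical stand-in, not the ck API):

    #include <cstddef>
    #include <fstream>
    #include <vector>

    // Stand-in for DeviceMem::FromDevice; here it just fabricates mask-like data.
    void copy_from_device(std::vector<float>& host_copy)
    {
        for(std::size_t i = 0; i < host_copy.size(); ++i)
            host_copy[i] = static_cast<float>(i % 2); // placeholder 0/1 "mask" values
    }

    int main()
    {
        std::vector<float> z_host(1024); // host shadow of the device z buffer
        copy_from_device(z_host);        // device -> host, as z_device_buf.FromDevice(...)

        std::ofstream file("./z_matrix_txt");
        for(float v : z_host)
            file << v << '\n';           // stream the tensor as text, one value per line
        return 0;
    }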
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp
@@ -96,6 +96,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
     static constexpr auto I5 = Number<5>{};
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};
+    static constexpr auto I8 = Number<8>{};
+    static constexpr auto I9 = Number<9>{};

     static constexpr auto WaveSize = 64;

     // K1 should be Number<...>
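I8 and I9 extend ck's usual family of compile-time index constants: Number<N> is an empty type that carries N in the type system, so descriptor dimensions can be indexed in constant expressions. A minimal sketch of the idiom with a simplified stand-in for ck's Number (not the real ck definition):

    #include <cstdio>

    // Simplified stand-in for ck::Number<N>: a value-less type encoding N.
    template <int N>
    struct Number
    {
        static constexpr int value = N;
        constexpr operator int() const { return N; }
    };

    static constexpr auto I8 = Number<8>{};
    static constexpr auto I9 = Number<9>{};

    int main()
    {
        int buf[Number<4>{}.value] = {}; // usable where a constant expression is required
        std::printf("I8=%d I9=%d len=%zu\n",
                    int(I8), int(I9), sizeof(buf) / sizeof(buf[0]));
        return 0;
    }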
@@ -1483,7 +1485,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto wave_id     = GetGemm0WaveIdx();
         const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
-        /* if(get_thread_global_1d_id() == 191)
+        if(get_thread_global_1d_id() == 191)
         {
             printf("wave_id{ %d, %d, %d}, wave_m_n_id{%d, %d}\n",
                    wave_id[I0],
...
@@ -1491,7 +1493,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
wave_id
[
I2
],
wave_id
[
I2
],
wave_m_n_id
[
I0
],
wave_m_n_id
[
I0
],
wave_m_n_id
[
I1
]);
wave_m_n_id
[
I1
]);
}*/
printf
(
"z grid descripter{%d, %d, %d, %d, %d, %d, %d, %d, %d, %d}
\n
"
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I0
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I1
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I2
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I3
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I4
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I5
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I6
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I7
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I8
),
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
.
GetLength
(
I9
));
}
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
z_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
ushort
,
ushort
,
ushort
,
ushort
,
...
@@ -1767,8 +1780,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const index_t num_gemm1_k_block_outer_loop = k_grid_desc_k0_n_k1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

         const index_t K = k_grid_desc_k0_n_k1.GetLength(I0) * k_grid_desc_k0_n_k1.GetLength(I2);
-        const float scale  = 1.0f / std::sqrt(K);
+        const float scalar = 1.0f / std::sqrt(K);

         // Initialize dQ
         qgrad_thread_buf.Clear();
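The renamed scalar is the standard softmax-attention scaling factor 1/sqrt(d_k); here the reduction length K is reassembled from the K0 x K1 split of the k grid descriptor (GetLength(I0) * GetLength(I2)). A minimal sketch of the arithmetic (the K0/K1 split values are illustrative):

    #include <cmath>
    #include <cstdio>

    int main()
    {
        const int K0 = 16, K1 = 8; // illustrative K0/K1 split of the descriptor
        const int K  = K0 * K1;    // full reduction length, i.e. the head dimension
        const float scalar = 1.0f / std::sqrt(static_cast<float>(K));
        std::printf("K = %d, scalar = %f\n", K, scalar); // 1/sqrt(128) ~ 0.0884
        return 0;
    }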
@@ -1849,14 +1862,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 }
                 else
                 {
-                    s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i];
+                    s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i];
                 }
             });
         }
         else
         {
             static_for<0, s_slash_p_thread_buf.Size(), 1>{}(
-                [&](auto i) { s_slash_p_thread_buf(i) = scale * s_slash_p_thread_buf[i]; });
+                [&](auto i) { s_slash_p_thread_buf(i) = scalar * s_slash_p_thread_buf[i]; });
         }

         block_sync_lds(); // wait for lds read in gemm0 blockwise gemm