Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3fb903fa
Commit
3fb903fa
authored
Dec 02, 2021
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into gemm_activation
parents
a8b539da
2cbb8976
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
25 additions
and
10 deletions
+25
-10
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+3
-2
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+5
-4
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+3
-3
composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp
...ernel/include/utility/static_buffer_of_vector_type_v2.hpp
+13
-0
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
3fb903fa
...
@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
StaticBufferOfVectorTypeV2
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBufferOfVectorTypeV2
<
AddressSpaceEnum_t
::
Vgpr
,
vector_type
<
FloatAcc
,
16
>
,
vector_type
<
FloatAcc
,
xdlops_gemm
.
GetRegSizePerXdlops
()
>
,
MRepeat
*
NRepeat
,
MRepeat
*
NRepeat
,
true
>
true
>
c_thread_buf_
;
c_thread_buf_
;
...
@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
3fb903fa
...
@@ -552,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -552,6 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// main body
// main body
index_t
k0_block_data_begin
=
0
;
index_t
k0_block_data_begin
=
0
;
c_thread_buf
.
Clear
();
if
constexpr
(
HasMainKBlockLoop
)
if
constexpr
(
HasMainKBlockLoop
)
{
{
do
do
...
@@ -591,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -591,6 +593,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory
// output: register to global memory
{
{
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
...
@@ -603,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -603,10 +608,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I6
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I7
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I7
);
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
M0
>
{},
Number
<
N0
>
{},
I1
,
I1
,
Number
<
M2
>
{},
I1
,
Number
<
M4
>
{},
I1
));
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
const
auto
c_thread_mtx_on_block
=
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
3fb903fa
...
@@ -507,7 +507,7 @@ struct MfmaSelector
...
@@ -507,7 +507,7 @@ struct MfmaSelector
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
__host__
__device__
static
constexpr
void
mfma_check
()
__host__
__device__
constexpr
MfmaSelector
()
{
{
static_assert
(
selected_mfma
.
group_size
*
selected_mfma
.
num_groups_per_blk
==
static_assert
(
selected_mfma
.
group_size
*
selected_mfma
.
num_groups_per_blk
==
selected_mfma
.
num_regs_per_blk
,
selected_mfma
.
num_regs_per_blk
,
...
@@ -533,8 +533,6 @@ struct MfmaSelector
...
@@ -533,8 +533,6 @@ struct MfmaSelector
"is_k_reduction wrong!"
);
"is_k_reduction wrong!"
);
}
}
__host__
__device__
constexpr
MfmaSelector
()
{
mfma_check
();
}
static
constexpr
bool
IsABroadcast
()
static
constexpr
bool
IsABroadcast
()
{
{
static_assert
(
NPerXdlops
>=
MPerXdlops
,
"only support ABroadcast"
);
static_assert
(
NPerXdlops
>=
MPerXdlops
,
"only support ABroadcast"
);
...
@@ -621,6 +619,8 @@ struct XdlopsGemm
...
@@ -621,6 +619,8 @@ struct XdlopsGemm
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
}
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
mfma_instr
.
wave_size
;
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
{
...
...
composable_kernel/include/utility/static_buffer_of_vector_type_v2.hpp
View file @
3fb903fa
...
@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
...
@@ -22,6 +22,13 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
static
constexpr
index_t
vector_size
=
GetVectorSize
();
static
constexpr
index_t
vector_size
=
GetVectorSize
();
__host__
__device__
static
constexpr
index_t
GetNumVectors
()
{
return
N
;
}
__host__
__device__
static
constexpr
index_t
GetNumElements
()
{
return
GetVectorSize
()
*
GetNumVectors
();
}
VecBaseType
invalid_element_value_
=
VecBaseType
{
0
};
VecBaseType
invalid_element_value_
=
VecBaseType
{
0
};
T
invalid_vec_value_
=
T
{
0
};
T
invalid_vec_value_
=
T
{
0
};
...
@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
...
@@ -91,6 +98,12 @@ struct StaticBufferOfVectorTypeV2 : public StaticallyIndexedArray<T, N>
return
GetElement
(
i
,
true
);
return
GetElement
(
i
,
true
);
}
}
__host__
__device__
void
Clear
()
{
static_for
<
0
,
GetNumElements
(),
1
>
{}(
[
&
](
auto
i
)
{
GetElement
(
i
,
true
)
=
invalid_element_value_
;
});
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
false
;
}
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
3fb903fa
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE
0
#define USE_DYNAMIC_MODE
1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V6R1_NCHW 0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment