Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d798c9b8
"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "7c35e13337a611258e0403298a268224816881f8"
Commit
d798c9b8
authored
Dec 02, 2021
by
Jing Zhang
Browse files
fixed c_buffer alloc
parent
4041850f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
10 deletions
+10
-10
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+3
-2
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
+3
-4
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+3
-3
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
d798c9b8
...
@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -38,7 +38,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
StaticBufferOfVectorTypeV2
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBufferOfVectorTypeV2
<
AddressSpaceEnum_t
::
Vgpr
,
vector_type
<
FloatAcc
,
16
>
,
vector_type
<
FloatAcc
,
xdlops_gemm
.
GetRegSizePerXdlops
()
>
,
MRepeat
*
NRepeat
,
MRepeat
*
NRepeat
,
true
>
true
>
c_thread_buf_
;
c_thread_buf_
;
...
@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -136,7 +136,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
View file @
d798c9b8
...
@@ -557,6 +557,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -557,6 +557,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// output: register to global memory
// output: register to global memory
{
{
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
...
@@ -569,10 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -569,10 +572,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I6
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I7
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I7
);
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
M0
>
{},
Number
<
N0
>
{},
I1
,
I1
,
Number
<
M2
>
{},
I1
,
Number
<
M4
>
{},
I1
));
// calculate origin of thread output tensor on global memory
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
const
auto
c_thread_mtx_on_block
=
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
d798c9b8
...
@@ -507,7 +507,7 @@ struct MfmaSelector
...
@@ -507,7 +507,7 @@ struct MfmaSelector
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
static
constexpr
auto
selected_mfma
=
mfma_type
<
GetMfma
<
base_type
,
MPerXdlops
,
NPerXdlops
>
()
>
{};
__host__
__device__
static
constexpr
void
mfma_check
()
__host__
__device__
constexpr
MfmaSelector
()
{
{
static_assert
(
selected_mfma
.
group_size
*
selected_mfma
.
num_groups_per_blk
==
static_assert
(
selected_mfma
.
group_size
*
selected_mfma
.
num_groups_per_blk
==
selected_mfma
.
num_regs_per_blk
,
selected_mfma
.
num_regs_per_blk
,
...
@@ -533,8 +533,6 @@ struct MfmaSelector
...
@@ -533,8 +533,6 @@ struct MfmaSelector
"is_k_reduction wrong!"
);
"is_k_reduction wrong!"
);
}
}
__host__
__device__
constexpr
MfmaSelector
()
{
mfma_check
();
}
static
constexpr
bool
IsABroadcast
()
static
constexpr
bool
IsABroadcast
()
{
{
static_assert
(
NPerXdlops
>=
MPerXdlops
,
"only support ABroadcast"
);
static_assert
(
NPerXdlops
>=
MPerXdlops
,
"only support ABroadcast"
);
...
@@ -621,6 +619,8 @@ struct XdlopsGemm
...
@@ -621,6 +619,8 @@ struct XdlopsGemm
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
return
MPerXdlops
*
NPerXdlops
/
mfma_instr
.
wave_size
;
}
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
mfma_instr
.
wave_size
;
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
__device__
void
Run
(
const
FloatA
&
p_a_wave
,
const
FloatB
&
p_b_wave
,
FloatC
&
p_c_thread
)
const
{
{
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
d798c9b8
...
@@ -18,7 +18,7 @@
...
@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE
0
#define USE_DYNAMIC_MODE
1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V6R1_NCHW 0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment