Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0acd3ebe
Commit
0acd3ebe
authored
Sep 03, 2021
by
ltqin
Browse files
start change gridwise k split
parent
1043ab4f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
76 additions
and
42 deletions
+76
-42
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+35
-19
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
...ard_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
+16
-13
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
+12
-10
host/driver_offline/src/conv_wrw_driver_offline.cpp
host/driver_offline/src/conv_wrw_driver_offline.cpp
+2
-0
host/host_tensor/include/host_tensor_generator.hpp
host/host_tensor/include/host_tensor_generator.hpp
+11
-0
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
0acd3ebe
...
@@ -148,7 +148,8 @@ template <index_t BlockSize,
...
@@ -148,7 +148,8 @@ template <index_t BlockSize,
typename
CGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
>
bool
CAccessOrderMRepeatNRepeat
,
index_t
KBatch
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -233,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -233,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
index_t
__host__
__device__
static
constexpr
index_t
Calculate
MN
GridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
...
@@ -243,15 +244,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -243,15 +244,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return
grid_size_mn
;
return
grid_size_mn
;
}
}
// Number of workgroups needed to tile an M x N output when each workgroup
// covers one MPerBlock x NPerBlock tile.
//
// Both quotients use integer division, so any remainder of M / MPerBlock or
// N / NPerBlock is dropped; the result is exact only when M and N are
// multiples of the respective tile sizes.
//
// M, N: extents of the two output dimensions.
// Returns (M / MPerBlock) * (N / NPerBlock).
__host__ __device__ static constexpr index_t CalculateGridSize(const index_t M, const index_t N)
{
    return (M / MPerBlock) * (N / NPerBlock);
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeABK0MK1GridDescriptor
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
,
const
index_t
kbatch
)
MakeABK0MK1GridDescriptor
(
const
AK0MK1GridDesc
&
a_k0_m_k1_grid_desc
)
{
{
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
M
=
a_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_k0_m_k1_grid_desc
.
GetLength
(
I1
);
assert
(
K0
%
KBatch
==
0
);
const
auto
a_b_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
const
auto
a_b_k0_m_k1_grid_desc
=
transform_tensor_descriptor
(
a_k0_m_k1_grid_desc
,
a_k0_m_k1_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
kb
atch
,
K0
/
kb
atch
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KB
atch
,
K0
/
KB
atch
)),
make_pass_through_transform
(
M
),
make_pass_through_transform
(
M
),
make_pass_through_transform
(
K1Value
)),
make_pass_through_transform
(
K1Value
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
...
@@ -260,14 +270,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -260,14 +270,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeBBK0NK1GridDescriptor
(
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
,
const
index_t
kbatch
)
MakeBBK0NK1GridDescriptor
(
const
BK0NK1GridDesc
&
b_k0_n_k1_grid_desc
)
{
{
const
auto
K0
=
b_k0_n_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
b_k0_n_k1_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
b_k0_n_k1_grid_desc
.
GetLength
(
I1
);
const
auto
N
=
b_k0_n_k1_grid_desc
.
GetLength
(
I1
);
assert
(
K0
%
KBatch
==
0
);
const
auto
b_b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
const
auto
b_b_k0_n_k1_grid_desc
=
transform_tensor_descriptor
(
b_k0_n_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
kb
atch
,
K0
/
kb
atch
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KB
atch
,
K0
/
KB
atch
)),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
),
make_pass_through_transform
(
K1Value
)),
make_pass_through_transform
(
K1Value
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
...
@@ -327,8 +339,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -327,8 +339,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
return
c_blockid_to_m0_n0_block_cluster_adaptor
;
}
}
using
ABK0MK1GridDesc
=
decltype
(
MakeABK0MK1GridDescriptor
(
AK0MK1GridDesc
{}
,
I1
));
using
ABK0MK1GridDesc
=
decltype
(
MakeABK0MK1GridDescriptor
(
AK0MK1GridDesc
{}));
using
BBK0NK1GridDesc
=
decltype
(
MakeBBK0NK1GridDescriptor
(
BK0NK1GridDesc
{}
,
I1
));
using
BBK0NK1GridDesc
=
decltype
(
MakeBBK0NK1GridDescriptor
(
BK0NK1GridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{}));
...
@@ -344,24 +356,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
...
@@ -344,24 +356,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
p_a_grid
,
a_
b_
k0_m_k1_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_grid
,
b_k0_n_k1_grid_desc
.
GetElementSpaceSize
());
p_b_grid
,
b_
b_k0_n_k1_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
p_c_grid
,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetElementSpaceSize
());
const
auto
K0
=
a_k0_m_k1_grid_desc
.
GetLength
(
I0
);
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
M
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
);
const
auto
kbatch
=
CalculateKBatch
(
CMNGridDesc
{},
b_k0_n_k1_grid_desc
);
const
auto
N
=
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
);
if
(
get_block_1d_id
()
==
0
)
const
auto
b_grid_size
=
CalculateGridSize
(
M
,
N
);
printf
(
"*****kbatch : %d, %d, %d, %d
\n
"
,
const
auto
nBatch
=
get_block_1d_id
()
/
b_grid_size
;
kbatch
,
const
auto
blockid_in_batch
=
get_block_1d_id
()
%
b_grid_size
;
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
),
if
(
get_block_1d_id
()
==
2000
)
b_b_k0_n_k1_grid_desc
.
GetLength
(
I0
),
printf
(
"grid size: %d, Batch: %d block_id: %d k0: %d
\n
"
,
b_grid_size
,
nBatch
,
blockid_in_batch
,
K0
);
K0
);
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_
block
_1
d_i
d
()
));
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
block
i
d_i
n_batch
));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
const
index_t
m_block_data_idx_on_grid
=
...
...
host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nkhw.hpp
View file @
0acd3ebe
...
@@ -75,6 +75,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
...
@@ -75,6 +75,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmK1
=
8
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector
=
1
;
constexpr
index_t
KBatch
=
96
;
#elif 1
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -167,7 +169,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
...
@@ -167,7 +169,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
TInWei
,
TInWei
,
TAcc
,
TAcc
,
TOut
,
TOut
,
InMemoryDataOperationEnum_t
::
Set
,
InMemoryDataOperationEnum_t
::
AtomicAdd
,
decltype
(
out_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_desc
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_desc
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_desc
),
decltype
(
wei_gemmm_gemmn_grid_desc
),
decltype
(
wei_gemmm_gemmn_grid_desc
),
...
@@ -203,18 +205,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
...
@@ -203,18 +205,19 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r3_xdlops_nchw_kcyx_nk
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
decltype
(
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
),
false
>
(
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
false
,
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
KBatch
>
(
static_cast
<
TOut
*>
(
out_n_k_ho_wo_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TInWei
*>
(
in_n_c_hi_wi_device_buf
.
GetDeviceBuffer
()),
out_gemmk0_gemmm_gemmk1_grid_desc
,
static_cast
<
TInWei
*>
(
wei_k_c_y_x_device_buf
.
GetDeviceBuffer
()),
in_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_desc
,
wei_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_desc
,
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_gemmm_gemmn_grid_desc
,
in_gemmk0_gemmn_gemmk1_grid_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks
,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks
,
nrepeat
);
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
float
perf
=
static_cast
<
float
>
(
calculate_convolution_flops
(
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
))
/
in_n_c_hi_wi_desc
,
wei_k_c_y_x_desc
,
out_n_k_ho_wo_desc
))
/
...
...
host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp
View file @
0acd3ebe
...
@@ -46,7 +46,8 @@ template <ck::index_t BlockSize,
...
@@ -46,7 +46,8 @@ template <ck::index_t BlockSize,
typename
CGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
bool
CAccessOrderMRepeatNRepeat
>
bool
CAccessOrderMRepeatNRepeat
,
ck
::
index_t
KBatch
>
__host__
float
driver_gemm_xdlops_v2r4
(
const
FloatAB
*
p_a_grid
,
__host__
float
driver_gemm_xdlops_v2r4
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
FloatC
*
p_c_grid
,
...
@@ -108,7 +109,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -108,7 +109,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
CGridStepHacks
,
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
,
CAccessOrderMRepeatNRepeat
>
;
CAccessOrderMRepeatNRepeat
,
KBatch
>
;
{
{
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"a_k0_m_k1_grid_desc{"
<<
a_k0_m_k1_grid_desc
.
GetLength
(
I0
)
<<
", "
...
@@ -122,13 +124,11 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -122,13 +124,11 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"c_m_n_grid_desc{ "
<<
c_m_n_grid_desc
.
GetLength
(
I0
)
<<
", "
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
c_m_n_grid_desc
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
const
auto
kbatch
=
GridwiseGemm
::
CalculateKBatch
(
c_m_n_grid_desc
,
b_k0_n_k1_grid_desc
);
// const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
const
auto
a_b_k0_m_k1_grid_desc
=
const
auto
a_b_k0_m_k1_grid_desc
=
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
);
GridwiseGemm
::
MakeABK0MK1GridDescriptor
(
a_k0_m_k1_grid_desc
,
kbatch
);
const
auto
b_b_k0_n_k1_grid_desc
=
GridwiseGemm
::
MakeBBK0NK1GridDescriptor
(
b_k0_n_k1_grid_desc
);
const
auto
b_b_k0_n_k1_grid_desc
=
GridwiseGemm
::
MakeBBK0NK1GridDescriptor
(
b_k0_n_k1_grid_desc
,
kbatch
);
{
{
std
::
cout
<<
"k batch number is: "
<<
kbatch
<<
std
::
endl
;
//
std::cout << "k batch number is: " << kbatch << std::endl;
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
))
if
(
!
GridwiseGemm
::
CheckValidity
(
a_k0_m_k1_grid_desc
,
b_k0_n_k1_grid_desc
,
c_m_n_grid_desc
))
{
{
...
@@ -147,8 +147,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -147,8 +147,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
using
CBlockClusterAdaptor
=
decltype
(
c_block_cluster_adaptor
);
const
index_t
grid_size_mn
=
GridwiseGemm
::
Calculate
MN
GridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size_mn
=
GridwiseGemm
::
CalculateGridSize
(
c_m_n_grid_desc
);
const
index_t
grid_size
=
grid_size_mn
*
kb
atch
;
const
index_t
grid_size
=
grid_size_mn
*
KB
atch
;
{
{
std
::
cout
<<
"mxn gridSize : "
<<
grid_size_mn
<<
" finally grid_size : "
<<
grid_size
std
::
cout
<<
"mxn gridSize : "
<<
grid_size_mn
<<
" finally grid_size : "
<<
grid_size
<<
std
::
endl
;
<<
std
::
endl
;
...
@@ -189,6 +189,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
...
@@ -189,6 +189,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
a_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_k0_m_k1_grid_desc
);
a_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_k0_m_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_k0_n_k1_grid_desc
);
b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_k0_n_k1_grid_desc
);
a_b_k0_m_k1_grid_desc_dev_buf
.
ToDevice
(
&
a_b_k0_m_k1_grid_desc
);
b_b_k0_n_k1_grid_desc_dev_buf
.
ToDevice
(
&
b_b_k0_n_k1_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
...
...
host/driver_offline/src/conv_wrw_driver_offline.cpp
View file @
0acd3ebe
...
@@ -267,6 +267,8 @@ int main(int argc, char* argv[])
...
@@ -267,6 +267,8 @@ int main(int argc, char* argv[])
{
{
throw
std
::
runtime_error
(
"wrong! layout"
);
throw
std
::
runtime_error
(
"wrong! layout"
);
}
}
// set zero to wei_device
wei_device
.
GenerateTensorValue
(
GeneratorTensor_0
{},
num_thread
);
const
auto
tmp
=
f_make_for_device_nchw
();
const
auto
tmp
=
f_make_for_device_nchw
();
...
...
host/host_tensor/include/host_tensor_generator.hpp
View file @
0acd3ebe
...
@@ -15,6 +15,17 @@ struct GeneratorTensor_1
...
@@ -15,6 +15,17 @@ struct GeneratorTensor_1
}
}
};
};
// Constant tensor generator: every invocation yields `value` (0 by default)
// converted to float, regardless of the arguments it is called with.
struct GeneratorTensor_0
{
    // The constant returned by operator(); defaults to zero.
    int value = 0;

    // Accepts any number of arguments of any type and ignores them all.
    // Marked const because the call never modifies the generator, so it can
    // also be invoked through a const object or const reference.
    template <typename... Is>
    float operator()(Is...) const
    {
        return static_cast<float>(value);
    }
};
struct
GeneratorTensor_2
struct
GeneratorTensor_2
{
{
int
min_value
=
0
;
int
min_value
=
0
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment