Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
65c56e56
"vscode:/vscode.git/clone" did not exist on "89ed418b8366066a8d54d4c2ae98454919fbc207"
Commit
65c56e56
authored
Jul 25, 2022
by
Chao Liu
Browse files
update Tensor
parent
028171e9
Changes
60
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
74 additions
and
73 deletions
+74
-73
profiler/include/profile_batched_gemm_impl.hpp
profiler/include/profile_batched_gemm_impl.hpp
+3
-3
profiler/include/profile_batched_gemm_reduce_impl.hpp
profiler/include/profile_batched_gemm_reduce_impl.hpp
+5
-5
profiler/include/profile_conv_bwd_data_impl.hpp
profiler/include/profile_conv_bwd_data_impl.hpp
+3
-3
profiler/include/profile_conv_bwd_weight_impl.hpp
profiler/include/profile_conv_bwd_weight_impl.hpp
+4
-3
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
+5
-5
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
+4
-4
profiler/include/profile_conv_fwd_impl.hpp
profiler/include/profile_conv_fwd_impl.hpp
+3
-3
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
+5
-5
profiler/include/profile_gemm_bias_add_reduce_impl.hpp
profiler/include/profile_gemm_bias_add_reduce_impl.hpp
+7
-7
profiler/include/profile_gemm_bilinear_impl.hpp
profiler/include/profile_gemm_bilinear_impl.hpp
+4
-4
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+3
-3
profiler/include/profile_gemm_reduce_impl.hpp
profiler/include/profile_gemm_reduce_impl.hpp
+5
-5
profiler/include/profile_gemm_splitk_impl.hpp
profiler/include/profile_gemm_splitk_impl.hpp
+3
-3
profiler/include/profile_grouped_gemm_impl.hpp
profiler/include/profile_grouped_gemm_impl.hpp
+3
-3
profiler/include/profile_normalization_impl.hpp
profiler/include/profile_normalization_impl.hpp
+2
-2
profiler/include/profile_reduce_impl.hpp
profiler/include/profile_reduce_impl.hpp
+3
-3
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+3
-3
test/gemm_split_k/gemm_split_k.cpp
test/gemm_split_k/gemm_split_k.cpp
+3
-3
test/layernorm/test_layernorm_util.hpp
test/layernorm/test_layernorm_util.hpp
+4
-4
test/softmax/test_softmax_util.hpp
test/softmax/test_softmax_util.hpp
+2
-2
No files found.
profiler/include/profile_batched_gemm_impl.hpp
View file @
65c56e56
...
@@ -114,9 +114,9 @@ bool profile_batched_gemm_impl(int do_verification,
...
@@ -114,9 +114,9 @@ bool profile_batched_gemm_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_g_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_g_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_g_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_g_k_n
.
mData
.
data
());
...
...
profiler/include/profile_batched_gemm_reduce_impl.hpp
View file @
65c56e56
...
@@ -193,13 +193,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
...
@@ -193,13 +193,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
}
}
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_g_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_g_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_g_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_g_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
reduce0_device_buf
(
sizeof
(
ReduceDataType
)
*
DeviceMem
reduce0_device_buf
(
sizeof
(
ReduceDataType
)
*
d0_g_m_device_result
.
mDesc
.
GetElementSpace
());
d0_g_m_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
reduce1_device_buf
(
sizeof
(
ReduceDataType
)
*
DeviceMem
reduce1_device_buf
(
sizeof
(
ReduceDataType
)
*
d1_g_m_device_result
.
mDesc
.
GetElementSpace
());
d1_g_m_device_result
.
mDesc
.
GetElementSpace
Size
());
std
::
array
<
void
*
,
2
>
p_reduces
=
{
reduce0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
2
>
p_reduces
=
{
reduce0_device_buf
.
GetDeviceBuffer
(),
reduce1_device_buf
.
GetDeviceBuffer
()};
reduce1_device_buf
.
GetDeviceBuffer
()};
...
...
profiler/include/profile_conv_bwd_data_impl.hpp
View file @
65c56e56
...
@@ -93,9 +93,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
...
@@ -93,9 +93,9 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
weight
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
weight
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
weight
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
output
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
output
.
mDesc
.
GetElementSpace
Size
());
out_device_buf
.
ToDevice
(
output
.
mData
.
data
());
out_device_buf
.
ToDevice
(
output
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
weight
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
weight
.
mData
.
data
());
...
...
profiler/include/profile_conv_bwd_weight_impl.hpp
View file @
65c56e56
...
@@ -99,9 +99,10 @@ bool profile_conv_bwd_weight_impl(int do_verification,
...
@@ -99,9 +99,10 @@ bool profile_conv_bwd_weight_impl(int do_verification,
output
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.5
,
0.5
});
output
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
weight_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
output
.
mDesc
.
GetElementSpace
());
weight_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
output
.
mDesc
.
GetElementSpaceSize
());
in_device_buf
.
ToDevice
(
input
.
mData
.
data
());
in_device_buf
.
ToDevice
(
input
.
mData
.
data
());
out_device_buf
.
ToDevice
(
output
.
mData
.
data
());
out_device_buf
.
ToDevice
(
output
.
mData
.
data
());
...
...
profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp
View file @
65c56e56
...
@@ -157,12 +157,12 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
...
@@ -157,12 +157,12 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
resi_device_buf
(
sizeof
(
OutDataType
)
*
resi_n_k_ho_wo
.
mDesc
.
GetElementSpace
());
DeviceMem
resi_device_buf
(
sizeof
(
OutDataType
)
*
resi_n_k_ho_wo
.
mDesc
.
GetElementSpace
Size
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
...
...
profiler/include/profile_conv_fwd_bias_relu_impl.hpp
View file @
65c56e56
...
@@ -149,11 +149,11 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
...
@@ -149,11 +149,11 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei_k_c_y_x
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
());
out_n_k_ho_wo_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias_k
.
mDesc
.
GetElementSpace
Size
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
...
...
profiler/include/profile_conv_fwd_impl.hpp
View file @
65c56e56
...
@@ -71,9 +71,9 @@ bool profile_conv_fwd_impl(int do_verification,
...
@@ -71,9 +71,9 @@ bool profile_conv_fwd_impl(int do_verification,
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
out_element_op
=
OutElementOp
{};
const
auto
out_element_op
=
OutElementOp
{};
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
weight
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
weight
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
device_output
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
device_output
.
mDesc
.
GetElementSpace
Size
());
in_device_buf
.
ToDevice
(
input
.
mData
.
data
());
in_device_buf
.
ToDevice
(
input
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
weight
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
weight
.
mData
.
data
());
...
...
profiler/include/profile_gemm_add_add_fastgelu_impl.hpp
View file @
65c56e56
...
@@ -146,11 +146,11 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
...
@@ -146,11 +146,11 @@ bool profile_gemm_add_add_fastgelu_impl(int do_verification,
}
}
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
d0_m_n_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_m_n_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
d1_m_n_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d1_m_n_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
...
profiler/include/profile_gemm_bias_add_reduce_impl.hpp
View file @
65c56e56
...
@@ -217,15 +217,15 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
...
@@ -217,15 +217,15 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
}
}
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
bias_device_buf
(
sizeof
(
BiasDataType
)
*
bias_n
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
BiasDataType
)
*
bias_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
reduce0_device_buf
(
sizeof
(
ReduceDataType
)
*
DeviceMem
reduce0_device_buf
(
sizeof
(
ReduceDataType
)
*
reduce0_m_device_result
.
mDesc
.
GetElementSpace
());
reduce0_m_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
reduce1_device_buf
(
sizeof
(
ReduceDataType
)
*
DeviceMem
reduce1_device_buf
(
sizeof
(
ReduceDataType
)
*
reduce1_m_device_result
.
mDesc
.
GetElementSpace
());
reduce1_m_device_result
.
mDesc
.
GetElementSpace
Size
());
std
::
array
<
void
*
,
2
>
p_reduces
=
{
reduce0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
2
>
p_reduces
=
{
reduce0_device_buf
.
GetDeviceBuffer
(),
reduce1_device_buf
.
GetDeviceBuffer
()};
reduce1_device_buf
.
GetDeviceBuffer
()};
...
...
profiler/include/profile_gemm_bilinear_impl.hpp
View file @
65c56e56
...
@@ -142,10 +142,10 @@ bool profile_gemm_bilinear_impl(int do_verification,
...
@@ -142,10 +142,10 @@ bool profile_gemm_bilinear_impl(int do_verification,
}
}
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
d_m_n_device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
d_m_n_device_buf
(
sizeof
(
DDataType
)
*
d_m_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
...
profiler/include/profile_gemm_impl.hpp
View file @
65c56e56
...
@@ -86,9 +86,9 @@ int profile_gemm_impl(int do_verification,
...
@@ -86,9 +86,9 @@ int profile_gemm_impl(int do_verification,
const
auto
b_element_op
=
BElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
...
profiler/include/profile_gemm_reduce_impl.hpp
View file @
65c56e56
...
@@ -189,13 +189,13 @@ bool profile_gemm_reduce_impl(int do_verification,
...
@@ -189,13 +189,13 @@ bool profile_gemm_reduce_impl(int do_verification,
}
}
}
}
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
reduce0_device_buf
(
sizeof
(
ReduceDataType
)
*
DeviceMem
reduce0_device_buf
(
sizeof
(
ReduceDataType
)
*
reduce0_m_device_result
.
mDesc
.
GetElementSpace
());
reduce0_m_device_result
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
reduce1_device_buf
(
sizeof
(
ReduceDataType
)
*
DeviceMem
reduce1_device_buf
(
sizeof
(
ReduceDataType
)
*
reduce1_m_device_result
.
mDesc
.
GetElementSpace
());
reduce1_m_device_result
.
mDesc
.
GetElementSpace
Size
());
std
::
array
<
void
*
,
2
>
p_reduces
=
{
reduce0_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
2
>
p_reduces
=
{
reduce0_device_buf
.
GetDeviceBuffer
(),
reduce1_device_buf
.
GetDeviceBuffer
()};
reduce1_device_buf
.
GetDeviceBuffer
()};
...
...
profiler/include/profile_gemm_splitk_impl.hpp
View file @
65c56e56
...
@@ -87,9 +87,9 @@ bool profile_gemm_splitk_impl(int do_verification,
...
@@ -87,9 +87,9 @@ bool profile_gemm_splitk_impl(int do_verification,
const
auto
b_element_op
=
BElementOp
{};
const
auto
b_element_op
=
BElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
const
auto
c_element_op
=
CElementOp
{};
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
...
profiler/include/profile_grouped_gemm_impl.hpp
View file @
65c56e56
...
@@ -152,12 +152,12 @@ void profile_grouped_gemm_impl(int do_verification,
...
@@ -152,12 +152,12 @@ void profile_grouped_gemm_impl(int do_verification,
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
{
{
a_device_buf
.
emplace_back
(
a_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
ADataType
)
*
a_m_k
[
i
].
mDesc
.
GetElementSpace
Size
()));
b_device_buf
.
emplace_back
(
b_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
()));
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
BDataType
)
*
b_k_n
[
i
].
mDesc
.
GetElementSpace
Size
()));
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
c_device_buf
.
emplace_back
(
std
::
make_unique
<
DeviceMem
>
(
sizeof
(
CDataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
()));
sizeof
(
CDataType
)
*
c_m_n_device_results
[
i
].
mDesc
.
GetElementSpace
Size
()));
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
a_device_buf
[
i
]
->
ToDevice
(
a_m_k
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
b_device_buf
[
i
]
->
ToDevice
(
b_k_n
[
i
].
mData
.
data
());
...
...
profiler/include/profile_normalization_impl.hpp
View file @
65c56e56
...
@@ -92,8 +92,8 @@ void profile_normalization_impl(int do_verification,
...
@@ -92,8 +92,8 @@ void profile_normalization_impl(int do_verification,
Tensor
<
OutDataType
>
out_ref
(
out
);
Tensor
<
OutDataType
>
out_ref
(
out
);
DeviceMem
in_dev
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
in_dev
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_dev
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpace
());
DeviceMem
out_dev
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpace
Size
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
out_dev
.
ToDevice
(
out
.
mData
.
data
());
out_dev
.
ToDevice
(
out
.
mData
.
data
());
...
...
profiler/include/profile_reduce_impl.hpp
View file @
65c56e56
...
@@ -245,13 +245,13 @@ bool profile_reduce_impl_impl(bool do_verification,
...
@@ -245,13 +245,13 @@ bool profile_reduce_impl_impl(bool do_verification,
}
}
if
(
beta
!=
0.0
f
)
if
(
beta
!=
0.0
f
)
for
(
size_t
i
=
0
;
i
<
out_ref
.
mDesc
.
GetElementSpace
();
i
++
)
for
(
size_t
i
=
0
;
i
<
out_ref
.
mDesc
.
GetElementSpace
Size
();
i
++
)
out
.
mData
[
i
]
=
out_ref
.
mData
[
i
];
out
.
mData
[
i
]
=
out_ref
.
mData
[
i
];
};
};
// these buffers are usually provided by the user application
// these buffers are usually provided by the user application
DeviceMem
in_dev
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
in_dev
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_dev
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpace
());
DeviceMem
out_dev
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpace
Size
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
...
...
test/gemm/gemm_util.hpp
View file @
65c56e56
...
@@ -71,9 +71,9 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
...
@@ -71,9 +71,9 @@ bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
Size
());
auto
invoker_ptr
=
gemmPtr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
gemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
auto
argument_ptr
=
...
...
test/gemm_split_k/gemm_split_k.cpp
View file @
65c56e56
...
@@ -127,9 +127,9 @@ int test_gemm(const gemmArgs& args)
...
@@ -127,9 +127,9 @@ int test_gemm(const gemmArgs& args)
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
DeviceMem
a_device_buf
(
sizeof
(
float
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
a_device_buf
(
sizeof
(
float
)
*
a_m_k
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
b_device_buf
(
sizeof
(
float
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
b_device_buf
(
sizeof
(
float
)
*
b_k_n
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
c_device_buf
(
sizeof
(
float
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c_device_buf
(
sizeof
(
float
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
Size
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
a_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
...
...
test/layernorm/test_layernorm_util.hpp
View file @
65c56e56
...
@@ -102,10 +102,10 @@ class TestLayernorm : public ::testing::Test
...
@@ -102,10 +102,10 @@ class TestLayernorm : public ::testing::Test
gamma
.
GenerateTensorValue
(
GeneratorTensor_3
<
GammaDataType
>
{
0.0
,
1.0
});
gamma
.
GenerateTensorValue
(
GeneratorTensor_3
<
GammaDataType
>
{
0.0
,
1.0
});
beta
.
GenerateTensorValue
(
GeneratorTensor_3
<
BetaDataType
>
{
0.0
,
1.0
});
beta
.
GenerateTensorValue
(
GeneratorTensor_3
<
BetaDataType
>
{
0.0
,
1.0
});
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpace
());
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
gamma_dev
(
sizeof
(
GammaDataType
)
*
gamma
.
mDesc
.
GetElementSpace
());
DeviceMem
gamma_dev
(
sizeof
(
GammaDataType
)
*
gamma
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
beta_dev
(
sizeof
(
BetaDataType
)
*
beta
.
mDesc
.
GetElementSpace
());
DeviceMem
beta_dev
(
sizeof
(
BetaDataType
)
*
beta
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
y_dev
(
sizeof
(
YDataType
)
*
y
.
mDesc
.
GetElementSpace
());
DeviceMem
y_dev
(
sizeof
(
YDataType
)
*
y
.
mDesc
.
GetElementSpace
Size
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
gamma_dev
.
ToDevice
(
gamma
.
mData
.
data
());
gamma_dev
.
ToDevice
(
gamma
.
mData
.
data
());
...
...
test/softmax/test_softmax_util.hpp
View file @
65c56e56
...
@@ -80,8 +80,8 @@ class TestSoftmax : public ::testing::Test
...
@@ -80,8 +80,8 @@ class TestSoftmax : public ::testing::Test
Tensor
<
OutDataType
>
out_ref
(
out
);
Tensor
<
OutDataType
>
out_ref
(
out
);
DeviceMem
in_dev
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
in_dev
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
Size
());
DeviceMem
out_dev
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpace
());
DeviceMem
out_dev
(
sizeof
(
OutDataType
)
*
out
.
mDesc
.
GetElementSpace
Size
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
in_dev
.
ToDevice
(
in
.
mData
.
data
());
out_dev
.
ToDevice
(
out
.
mData
.
data
());
out_dev
.
ToDevice
(
out
.
mData
.
data
());
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment