Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f9c478e2
Commit
f9c478e2
authored
May 30, 2022
by
ltqin
Browse files
Merge branch 'develop' into bmatrix_skip_lds
parents
7d85d04a
91d8b7d6
Changes
347
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
684 additions
and
295 deletions
+684
-295
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+5
-6
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+3
-0
library/include/ck/library/host/host_interface.hpp
library/include/ck/library/host/host_interface.hpp
+54
-0
library/include/ck/library/host_tensor/device.hpp
library/include/ck/library/host_tensor/device.hpp
+75
-36
library/include/ck/library/host_tensor/host_common_util.hpp
library/include/ck/library/host_tensor/host_common_util.hpp
+102
-0
library/include/ck/library/host_tensor/host_reduce_util.hpp
library/include/ck/library/host_tensor/host_reduce_util.hpp
+7
-19
library/include/ck/library/host_tensor/host_reduction.hpp
library/include/ck/library/host_tensor/host_reduction.hpp
+8
-7
library/include/ck/library/host_tensor/host_tensor.hpp
library/include/ck/library/host_tensor/host_tensor.hpp
+2
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
...reference_tensor_operation/cpu/reference_batched_gemm.hpp
+2
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
...e_tensor_operation/cpu/reference_conv_backward_weight.hpp
+171
-42
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...eference_tensor_operation/cpu/reference_conv_bwd_data.hpp
+59
-31
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+57
-31
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
...nsor_operation/cpu/reference_conv_fwd_bias_activation.hpp
+18
-10
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
..._operation/cpu/reference_conv_fwd_bias_activation_add.hpp
+18
-10
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+10
-11
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
...reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
+2
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp
...e_tensor_operation/cpu/reference_gemm_bias_activation.hpp
+2
-1
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp
...nsor_operation/cpu/reference_gemm_bias_activation_add.hpp
+2
-1
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp
..._operation_instance/gpu/reduce/device_reduce_instance.hpp
+1
-16
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
..._instance/gpu/reduce/device_reduce_instance_blockwise.hpp
+86
-70
No files found.
include/ck/utility/tuple.hpp
View file @
f9c478e2
...
@@ -21,9 +21,9 @@ struct TupleElement
...
@@ -21,9 +21,9 @@ struct TupleElement
{
{
__host__
__device__
constexpr
TupleElement
()
=
default
;
__host__
__device__
constexpr
TupleElement
()
=
default
;
template
<
typename
T
,
template
<
typename
enable_if
<!
is_same
<
remove_reference_t
<
remove_cv_t
<
T
>
>
,
TupleElement
>::
value
,
typename
T
,
bool
>::
type
=
false
>
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElement
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
__host__
__device__
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
{
{
}
}
...
@@ -60,7 +60,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
...
@@ -60,7 +60,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
template
<
typename
Y
,
template
<
typename
Y
,
typename
enable_if
<
sizeof
...(
Is
)
==
1
&&
sizeof
...(
Xs
)
==
1
&&
typename
enable_if
<
sizeof
...(
Is
)
==
1
&&
sizeof
...(
Xs
)
==
1
&&
!
is_same
<
remove_ref
erence_t
<
remove_cv
_t
<
Y
>
>
,
TupleImpl
>::
value
,
!
is_same
<
remove_
cv
ref_t
<
Y
>,
TupleImpl
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
...
@@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
constexpr
Tuple
()
=
default
;
__host__
__device__
constexpr
Tuple
()
=
default
;
template
<
typename
Y
,
template
<
typename
Y
,
typename
enable_if
<
sizeof
...(
Xs
)
==
1
&&
typename
enable_if
<
sizeof
...(
Xs
)
==
1
&&
!
is_same
<
remove_cvref_t
<
Y
>,
Tuple
>::
value
,
!
is_same
<
remove_reference_t
<
remove_cv_t
<
Y
>
>
,
Tuple
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
std
::
forward
<
Y
>
(
y
))
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
std
::
forward
<
Y
>
(
y
))
{
{
...
...
include/ck/utility/type.hpp
View file @
f9c478e2
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
template
<
typename
T
>
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
...
...
library/include/ck/library/host/host_interface.hpp
0 → 100644
View file @
f9c478e2
#pragma once
#include <memory>
#include <string>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"
struct
DeviceConvFwdPtr_t
{
using
BaseArgument
=
ck
::
tensor_operation
::
device
::
BaseArgument
;
using
BaseInvoker
=
ck
::
tensor_operation
::
device
::
BaseInvoker
;
struct
DeviceConvFwdPtrImpl
;
std
::
unique_ptr
<
DeviceConvFwdPtrImpl
>
pImpl
;
DeviceConvFwdPtr_t
();
~
DeviceConvFwdPtr_t
();
DeviceConvFwdPtr_t
(
DeviceConvFwdPtr_t
&&
);
DeviceConvFwdPtr_t
(
DeviceConvFwdPtrImpl
&
);
DeviceConvFwdPtr_t
&
operator
=
(
DeviceConvFwdPtr_t
&
)
=
delete
;
DeviceConvFwdPtr_t
&
operator
=
(
const
DeviceConvFwdPtr_t
&
)
=
delete
;
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
in_ptr
,
void
*
wei_ptr
,
void
*
out_ptr
,
size_t
N
,
size_t
K
,
size_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
const
;
// in,wei and out element ops are ignored for now since even if we change them, they
// cant be linked
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
const
;
// requires including BaseInvoker headers
std
::
string
GetTypeString
();
bool
IsSupportedArgument
(
const
BaseArgument
*
arg_ptr
);
};
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
void
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t
(
std
::
vector
<
DeviceConvFwdPtr_t
>&
instances
);
library/include/ck/library/host_tensor/device.hpp
View file @
f9c478e2
#ifndef DEVICE_HPP
#pragma once
#define DEVICE_HPP
#include <memory>
#include <memory>
#include <functional>
#include <functional>
#include <thread>
#include <thread>
#include <chrono>
#include <chrono>
#include "hip/hip_runtime.h"
#include <hip/hip_runtime.h>
#include "hip/hip_fp16.h"
#include <hip/hip_fp16.h>
#include "stream_config.hpp"
#include "ck/options.hpp"
template
<
typename
T
>
__global__
void
set_buffer_value
(
T
*
p
,
T
x
,
uint64_t
buffer_element_size
)
{
for
(
uint64_t
i
=
threadIdx
.
x
;
i
<
buffer_element_size
;
i
+=
blockDim
.
x
)
{
p
[
i
]
=
x
;
}
}
inline
void
hip_check_error
(
hipError_t
x
)
{
if
(
x
!=
hipSuccess
)
{
std
::
ostringstream
ss
;
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
__FILE__
<<
": "
<<
__LINE__
<<
"in function: "
<<
__func__
;
throw
std
::
runtime_error
(
ss
.
str
());
}
}
struct
DeviceMem
struct
DeviceMem
{
{
...
@@ -17,6 +39,16 @@ struct DeviceMem
...
@@ -17,6 +39,16 @@ struct DeviceMem
void
ToDevice
(
const
void
*
p
);
void
ToDevice
(
const
void
*
p
);
void
FromDevice
(
void
*
p
);
void
FromDevice
(
void
*
p
);
void
SetZero
();
void
SetZero
();
template
<
typename
T
>
void
SetValue
(
T
x
)
{
if
(
mMemSize
%
sizeof
(
T
)
!=
0
)
{
throw
std
::
runtime_error
(
"wrong! not entire DeviceMem will be set"
);
}
set_buffer_value
<
T
><<<
1
,
1024
>>>
(
static_cast
<
T
*>
(
mpDeviceBuf
),
x
,
mMemSize
/
sizeof
(
T
));
}
~
DeviceMem
();
~
DeviceMem
();
void
*
mpDeviceBuf
;
void
*
mpDeviceBuf
;
...
@@ -36,49 +68,56 @@ struct KernelTimer
...
@@ -36,49 +68,56 @@ struct KernelTimer
std
::
unique_ptr
<
KernelTimerImpl
>
impl
;
std
::
unique_ptr
<
KernelTimerImpl
>
impl
;
};
};
using
device_stream_t
=
hipStream_t
;
template
<
typename
...
Args
,
typename
F
>
template
<
typename
...
Args
,
typename
F
>
void
launch_kernel
(
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
float
launch_and_time_kernel
(
const
StreamConfig
&
stream_config
,
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
{
{
hipStream_t
stream_id
=
nullptr
;
#if CK_TIME_KERNEL
if
(
stream_config
.
time_kernel_
)
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
{
}
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
template
<
typename
...
Args
,
typename
F
>
const
int
nrepeat
=
10
;
float
launch_and_time_kernel
(
F
kernel
,
int
nrepeat
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
{
KernelTimer
timer
;
printf
(
"%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}
\n
"
,
printf
(
"Warm up 1 time
\n
"
);
__func__
,
grid_dim
.
x
,
grid_dim
.
y
,
grid_dim
.
z
,
block_dim
.
x
,
block_dim
.
y
,
block_dim
.
z
);
printf
(
"Warm up
\n
"
);
// warm up
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
hipStream_t
stream_id
=
nullptr
;
printf
(
"Start running %d times...
\n
"
,
nrepeat
)
;
// warm up
KernelTimer
timer
;
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...
);
timer
.
Start
(
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
}
timer
.
Start
();
timer
.
End
();
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
return
timer
.
GetElapsedTime
()
/
nrepeat
;
{
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
}
}
else
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
timer
.
End
();
return
0
;
}
#else
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
return
timer
.
GetElapsedTime
()
/
nrepeat
;
return
0
;
}
#endif
#endif
}
library/include/ck/library/host_tensor/host_common_util.hpp
0 → 100644
View file @
f9c478e2
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_COMMON_UTIL_HPP
#define GUARD_HOST_COMMON_UTIL_HPP
#include <vector>
#include <iostream>
#include <fstream>
#include <string>
#include "config.hpp"
namespace
ck
{
namespace
host_common
{
template
<
typename
T
>
static
inline
void
dumpBufferToFile
(
const
char
*
fileName
,
T
*
data
,
size_t
dataNumItems
)
{
std
::
ofstream
outFile
(
fileName
,
std
::
ios
::
binary
);
if
(
outFile
)
{
outFile
.
write
(
reinterpret_cast
<
char
*>
(
data
),
dataNumItems
*
sizeof
(
T
));
outFile
.
close
();
std
::
cout
<<
"Write output to file "
<<
fileName
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"Could not open file "
<<
fileName
<<
" for writing"
<<
std
::
endl
;
}
};
template
<
typename
T
>
static
inline
T
getSingleValueFromString
(
const
std
::
string
&
valueStr
)
{
std
::
istringstream
iss
(
valueStr
);
T
val
;
iss
>>
val
;
return
(
val
);
};
template
<
typename
T
>
static
inline
std
::
vector
<
T
>
getTypeValuesFromString
(
const
char
*
cstr_values
)
{
std
::
string
valuesStr
(
cstr_values
);
std
::
vector
<
T
>
values
;
std
::
size_t
pos
=
0
;
std
::
size_t
new_pos
;
new_pos
=
valuesStr
.
find
(
','
,
pos
);
while
(
new_pos
!=
std
::
string
::
npos
)
{
const
std
::
string
sliceStr
=
valuesStr
.
substr
(
pos
,
new_pos
-
pos
);
T
val
=
getSingleValueFromString
<
T
>
(
sliceStr
);
values
.
push_back
(
val
);
pos
=
new_pos
+
1
;
new_pos
=
valuesStr
.
find
(
','
,
pos
);
};
std
::
string
sliceStr
=
valuesStr
.
substr
(
pos
);
T
val
=
getSingleValueFromString
<
T
>
(
sliceStr
);
values
.
push_back
(
val
);
return
(
values
);
}
};
// namespace host_common
};
// namespace ck
#endif
library/include/ck/library/host_tensor/host_reduce_util.hpp
View file @
f9c478e2
...
@@ -28,9 +28,7 @@
...
@@ -28,9 +28,7 @@
#include <limits>
#include <limits>
#include <cmath>
#include <cmath>
#include <cassert>
#include <functional>
#include <stdexcept>
#include <string>
#include "reduction_enums.hpp"
#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "data_type.hpp"
...
@@ -214,13 +212,13 @@ binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
...
@@ -214,13 +212,13 @@ binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
};
};
};
};
template
<
typename
AccDataType
,
bool
PropagateNan
>
template
<
typename
AccDataType
,
typename
IndexDataType
,
bool
PropagateNan
>
__host__
static
inline
void
__host__
static
inline
void
binop_with_nan_check
2
(
std
::
function
<
void
(
AccDataType
&
,
AccDataType
,
bool
&
)
>
opReduce
,
binop_with_
index_and_
nan_check
(
std
::
function
<
void
(
AccDataType
&
,
AccDataType
,
bool
&
)
>
opReduce
,
AccDataType
&
accuVal
,
AccDataType
&
accuVal
,
AccDataType
currVal
,
AccDataType
currVal
,
int
&
accuIndex
,
IndexDataType
&
accuIndex
,
int
currIndex
)
IndexDataType
currIndex
)
{
{
using
ck
::
math
::
isnan
;
using
ck
::
math
::
isnan
;
...
@@ -254,16 +252,6 @@ binop_with_nan_check2(std::function<void(AccDataType&, AccDataType, bool&)> opRe
...
@@ -254,16 +252,6 @@ binop_with_nan_check2(std::function<void(AccDataType&, AccDataType, bool&)> opRe
};
// namespace host_reduce
};
// namespace host_reduce
static
inline
std
::
vector
<
int
>
to_int_vector
(
const
std
::
vector
<
size_t
>&
inData
)
{
std
::
vector
<
int
>
outData
;
for
(
auto
elem
:
inData
)
outData
.
push_back
(
static_cast
<
int
>
(
elem
));
return
(
outData
);
};
};
// namespace ck
};
// namespace ck
#endif
#endif
library/include/ck/library/host_tensor/host_reduction.hpp
View file @
f9c478e2
...
@@ -34,6 +34,7 @@
...
@@ -34,6 +34,7 @@
#include "reduction_enums.hpp"
#include "reduction_enums.hpp"
#include "reduction_common.hpp"
#include "reduction_common.hpp"
#include "host_reduce_util.hpp"
#include "host_reduce_util.hpp"
#include "host_common_util.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "data_type.hpp"
#include "data_type.hpp"
...
@@ -200,7 +201,7 @@ struct ReductionHost
...
@@ -200,7 +201,7 @@ struct ReductionHost
using
ck
::
float_equal_one
;
using
ck
::
float_equal_one
;
using
ck
::
float_equal_zero
;
using
ck
::
float_equal_zero
;
using
ck
::
type_convert
;
using
ck
::
type_convert
;
using
ck
::
host_reduce
::
binop_with_nan_check
2
;
using
ck
::
host_reduce
::
binop_with_
index_and_
nan_check
;
using
ck
::
host_reduce
::
ReduceOpFn2
;
using
ck
::
host_reduce
::
ReduceOpFn2
;
using
ck
::
host_reduce
::
ReduceOpZeroVal
;
using
ck
::
host_reduce
::
ReduceOpZeroVal
;
...
@@ -211,7 +212,7 @@ struct ReductionHost
...
@@ -211,7 +212,7 @@ struct ReductionHost
AccDataType
accuVal
=
ReduceOpZeroVal
<
AccDataType
,
ReduceOpId
>
();
AccDataType
accuVal
=
ReduceOpZeroVal
<
AccDataType
,
ReduceOpId
>
();
IndexDataType
accuIndex
=
0
;
IndexDataType
accuIndex
=
0
;
for
(
IndexDataType
i
=
0
;
i
<
reduce_dim_indexes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
reduce_dim_indexes
.
size
();
i
++
)
{
{
auto
offset_reduce
=
auto
offset_reduce
=
get_offset_from_index
<
NumReduceDim
>
(
reduceStrides
,
reduce_dim_indexes
[
i
]);
get_offset_from_index
<
NumReduceDim
>
(
reduceStrides
,
reduce_dim_indexes
[
i
]);
...
@@ -220,9 +221,9 @@ struct ReductionHost
...
@@ -220,9 +221,9 @@ struct ReductionHost
preUnaryOp
(
currVal
);
preUnaryOp
(
currVal
);
auto
currIndex
=
i
;
auto
currIndex
=
static_cast
<
IndexDataType
>
(
i
)
;
binop_with_nan_check
2
<
AccDataType
,
PropagateNan
>
(
binop_with_
index_and_
nan_check
<
AccDataType
,
IndexDataType
,
PropagateNan
>
(
opReduce2
,
accuVal
,
currVal
,
accuIndex
,
currIndex
);
opReduce2
,
accuVal
,
currVal
,
accuIndex
,
currIndex
);
};
};
...
@@ -246,7 +247,7 @@ struct ReductionHost
...
@@ -246,7 +247,7 @@ struct ReductionHost
auto
offset_invariant
=
auto
offset_invariant
=
get_offset_from_index
<
NumInvariantDim
>
(
invariantStrides
,
invariant_index
);
get_offset_from_index
<
NumInvariantDim
>
(
invariantStrides
,
invariant_index
);
for
(
IndexDataType
i
=
0
;
i
<
reduce_dim_indexes
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
reduce_dim_indexes
.
size
();
i
++
)
{
{
auto
offset_reduce
=
auto
offset_reduce
=
get_offset_from_index
<
NumReduceDim
>
(
reduceStrides
,
reduce_dim_indexes
[
i
]);
get_offset_from_index
<
NumReduceDim
>
(
reduceStrides
,
reduce_dim_indexes
[
i
]);
...
@@ -256,9 +257,9 @@ struct ReductionHost
...
@@ -256,9 +257,9 @@ struct ReductionHost
preUnaryOp
(
currVal
);
preUnaryOp
(
currVal
);
auto
currIndex
=
i
;
auto
currIndex
=
static_cast
<
IndexDataType
>
(
i
)
;
binop_with_nan_check
2
<
AccDataType
,
PropagateNan
>
(
binop_with_
index_and_
nan_check
<
AccDataType
,
IndexDataType
,
PropagateNan
>
(
opReduce2
,
accuVal
,
currVal
,
accuIndex
,
currIndex
);
opReduce2
,
accuVal
,
currVal
,
accuIndex
,
currIndex
);
};
};
...
...
library/include/ck/library/host_tensor/host_tensor.hpp
View file @
f9c478e2
...
@@ -154,7 +154,7 @@ struct ParallelTensorFunctor
...
@@ -154,7 +154,7 @@ struct ParallelTensorFunctor
{
{
std
::
array
<
std
::
size_t
,
NDIM
>
indices
;
std
::
array
<
std
::
size_t
,
NDIM
>
indices
;
for
(
in
t
idim
=
0
;
idim
<
NDIM
;
++
idim
)
for
(
std
::
size_
t
idim
=
0
;
idim
<
NDIM
;
++
idim
)
{
{
indices
[
idim
]
=
i
/
mStrides
[
idim
];
indices
[
idim
]
=
i
/
mStrides
[
idim
];
i
-=
indices
[
idim
]
*
mStrides
[
idim
];
i
-=
indices
[
idim
]
*
mStrides
[
idim
];
...
@@ -316,7 +316,7 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -316,7 +316,7 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
constexpr
float
eps
=
1e-10
;
constexpr
float
eps
=
1e-10
;
for
(
in
t
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
for
(
std
::
size_
t
i
=
0
;
i
<
ref
.
mData
.
size
();
++
i
)
{
{
float
ref_v
=
ck
::
type_convert
<
float
>
(
ref
.
mData
[
i
]);
float
ref_v
=
ck
::
type_convert
<
float
>
(
ref
.
mData
[
i
]);
float
result_v
=
ck
::
type_convert
<
float
>
(
result
.
mData
[
i
]);
float
result_v
=
ck
::
type_convert
<
float
>
(
result
.
mData
[
i
]);
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp
View file @
f9c478e2
...
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
...
@@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
View file @
f9c478e2
#ifndef REFERENCE_CONV_WRW_HPP
#pragma once
#define REFERENCE_CONV_WRW_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
...
@@ -16,7 +15,9 @@ template <typename InDataType,
...
@@ -16,7 +15,9 @@ template <typename InDataType,
typename
OutDataType
,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
,
ck
::
index_t
NumDimSpatial
=
2
,
typename
ck
::
enable_if
<
NumDimSpatial
>
=
1
&&
NumDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvBwdWeight
:
public
device
::
BaseOperator
struct
ReferenceConvBwdWeight
:
public
device
::
BaseOperator
{
{
// Argument
// Argument
...
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
InElementwiseOperation
in_element_op
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
)
:
in
_n_c_hi_wi
_
{
in_n_c_hi_wi
},
:
in
put
_
{
in_n_c_hi_wi
},
wei
_k_c_y_x
_
{
wei_k_c_y_x
},
wei
ght
_
{
wei_k_c_y_x
},
out
_n_k_ho_wo
_
{
out_n_k_ho_wo
},
out
put
_
{
out_n_k_ho_wo
},
conv_strides_
{
conv_filter_strides
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
in_left_pads_
{
input_left_pads
},
...
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{
{
}
}
const
Tensor
<
InDataType
>&
in
_n_c_hi_wi
_
;
const
Tensor
<
InDataType
>&
in
put
_
;
Tensor
<
WeiDataType
>&
wei
_k_c_y_x
_
;
Tensor
<
WeiDataType
>&
wei
ght
_
;
const
Tensor
<
OutDataType
>&
out
_n_k_ho_wo
_
;
const
Tensor
<
OutDataType
>&
out
put
_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
std
::
vector
<
index_t
>
conv_dilations_
;
...
@@ -66,55 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -66,55 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
if
constexpr
(
NumDimSpatial
==
1
)
constexpr
auto
I1
=
Number
<
1
>
{};
{
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
float
v_acc
=
0
;
auto
f_kcx
=
[
&
](
auto
k
,
auto
c
,
auto
x
)
{
for
(
int
n
=
0
;
n
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
float
v_acc
=
0
;
{
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
for
(
int
ho
=
0
;
ho
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
{
{
int
hi
=
ho
*
arg
.
conv_strides_
[
I0
]
+
y
*
arg
.
conv_dilations_
[
I0
]
-
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
wo
)
arg
.
in_left_pads_
[
I0
];
for
(
int
wo
=
0
;
wo
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
{
{
int
wi
=
wo
*
arg
.
conv_strides_
[
I1
]
+
x
*
arg
.
conv_dilations_
[
I1
]
-
auto
wi
=
arg
.
in_left_pads_
[
I1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I0
])
+
if
(
hi
>=
0
&&
hi
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I0
])
-
wi
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
{
{
float
v_out
;
float
v_out
;
float
v_in
;
float
v_in
;
arg
.
out_element_op_
(
arg
.
out_element_op_
(
v_out
,
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
wo
)));
ck
::
type_convert
<
float
>
(
arg
.
out_n_k_ho_wo_
(
n
,
k
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
arg
.
in_element_op_
(
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
wi
)));
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
in_n_c_hi_wi_
(
n
,
c
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
v_acc
+=
v_out
*
v_in
;
}
}
}
}
}
}
}
float
v_wei
;
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
wei
_k_c_y_x
_
(
k
,
c
,
y
,
x
)
=
ck
::
type_convert
<
Out
DataType
>
(
v_wei
);
arg
.
wei
ght
_
(
k
,
c
,
x
)
=
ck
::
type_convert
<
Wei
DataType
>
(
v_wei
);
};
};
make_ParallelTensorFunctor
(
f_kcyx
,
make_ParallelTensorFunctor
(
f_kcx
,
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
])(
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
2
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
weight_
(
k
,
c
,
y
,
x
)
=
ck
::
type_convert
<
WeiDataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kcyx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
3
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
auto
f_kczyx
=
[
&
](
auto
k
,
auto
c
,
auto
z
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
do_
=
0
;
do_
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
do_
)
{
auto
di
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
do_
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
++
ho
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
4
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I2
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I2
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I2
]);
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
do_
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
di
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
}
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
weight_
(
k
,
c
,
z
,
y
,
x
)
=
ck
::
type_convert
<
WeiDataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kczyx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
@@ -174,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
...
@@ -174,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
View file @
f9c478e2
...
@@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator
AccDataType
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
in
t
x
=
0
;
x
<
X
;
++
x
)
for
(
std
::
size_
t
x
=
0
;
x
<
X
;
++
x
)
{
{
int
w_tmp
=
wi
+
arg
.
in_left_pads_
[
0
]
-
x
*
arg
.
conv_dilations_
[
0
];
auto
w_tmp
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wi
)
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
]);
if
(
w_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
if
(
w_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
{
{
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
0
];
auto
wo
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
w_tmp
)
/
if
(
wo
>=
0
&&
wo
<
Wo
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
conv_strides_
[
0
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
{
for
(
in
t
k
=
0
;
k
<
K
;
++
k
)
for
(
std
::
size_
t
k
=
0
;
k
<
K
;
++
k
)
{
{
AccDataType
v_out
=
0
;
AccDataType
v_out
=
0
;
AccDataType
v_wei
=
0
;
AccDataType
v_wei
=
0
;
...
@@ -128,24 +131,32 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -128,24 +131,32 @@ struct ReferenceConvBwdData : public device::BaseOperator
AccDataType
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
in
t
y
=
0
;
y
<
Y
;
++
y
)
for
(
std
::
size_
t
y
=
0
;
y
<
Y
;
++
y
)
{
{
int
h_tmp
=
hi
+
arg
.
in_left_pads_
[
0
]
-
y
*
arg
.
conv_dilations_
[
0
];
auto
h_tmp
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
hi
)
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
]);
if
(
h_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
if
(
h_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
{
{
int
ho
=
h_tmp
/
arg
.
conv_strides_
[
0
];
auto
ho
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
h_tmp
)
/
if
(
ho
>=
0
&&
ho
<
Ho
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
conv_strides_
[
0
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
{
{
for
(
in
t
x
=
0
;
x
<
X
;
++
x
)
for
(
std
::
size_
t
x
=
0
;
x
<
X
;
++
x
)
{
{
int
w_tmp
=
auto
w_tmp
=
wi
+
arg
.
in_left_pads_
[
1
]
-
x
*
arg
.
conv_dilations_
[
1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wi
)
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
]);
if
(
w_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
if
(
w_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
{
{
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
1
];
auto
wo
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
w_tmp
)
/
if
(
wo
>=
0
&&
wo
<
Wo
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
conv_strides_
[
1
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
{
for
(
in
t
k
=
0
;
k
<
K
;
++
k
)
for
(
std
::
size_
t
k
=
0
;
k
<
K
;
++
k
)
{
{
AccDataType
v_out
=
0
;
AccDataType
v_out
=
0
;
AccDataType
v_wei
=
0
;
AccDataType
v_wei
=
0
;
...
@@ -194,33 +205,49 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -194,33 +205,49 @@ struct ReferenceConvBwdData : public device::BaseOperator
AccDataType
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
in
t
z
=
0
;
z
<
Z
;
++
z
)
for
(
std
::
size_
t
z
=
0
;
z
<
Z
;
++
z
)
{
{
int
d_tmp
=
di
+
arg
.
in_left_pads_
[
0
]
-
z
*
arg
.
conv_dilations_
[
0
];
auto
d_tmp
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
di
)
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
]);
if
(
d_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
if
(
d_tmp
%
arg
.
conv_strides_
[
0
]
==
0
)
{
{
int
do_
=
d_tmp
/
arg
.
conv_strides_
[
0
];
auto
do_
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
d_tmp
)
/
if
(
do_
>=
0
&&
do_
<
Do
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
conv_strides_
[
0
]);
if
(
do_
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
do_
)
<
Do
)
{
{
for
(
in
t
y
=
0
;
y
<
Y
;
++
y
)
for
(
std
::
size_
t
y
=
0
;
y
<
Y
;
++
y
)
{
{
int
h_tmp
=
auto
h_tmp
=
hi
+
arg
.
in_left_pads_
[
1
]
-
y
*
arg
.
conv_dilations_
[
1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
hi
)
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
]);
if
(
h_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
if
(
h_tmp
%
arg
.
conv_strides_
[
1
]
==
0
)
{
{
int
ho
=
h_tmp
/
arg
.
conv_strides_
[
1
];
auto
ho
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
h_tmp
)
/
if
(
ho
>=
0
&&
ho
<
Ho
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
conv_strides_
[
1
]);
if
(
ho
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
ho
)
<
Ho
)
{
{
for
(
in
t
x
=
0
;
x
<
X
;
++
x
)
for
(
std
::
size_
t
x
=
0
;
x
<
X
;
++
x
)
{
{
int
w_tmp
=
wi
+
arg
.
in_left_pads_
[
2
]
-
auto
w_tmp
=
x
*
arg
.
conv_dilations_
[
2
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wi
)
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
2
]);
if
(
w_tmp
%
arg
.
conv_strides_
[
2
]
==
0
)
if
(
w_tmp
%
arg
.
conv_strides_
[
2
]
==
0
)
{
{
int
wo
=
w_tmp
/
arg
.
conv_strides_
[
2
];
auto
wo
=
if
(
wo
>=
0
&&
wo
<
Wo
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
w_tmp
)
/
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
conv_strides_
[
2
]);
if
(
wo
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wo
)
<
Wo
)
{
{
for
(
in
t
k
=
0
;
k
<
K
;
++
k
)
for
(
std
::
size_
t
k
=
0
;
k
<
K
;
++
k
)
{
{
AccDataType
v_out
=
0
;
AccDataType
v_out
=
0
;
AccDataType
v_wei
=
0
;
AccDataType
v_wei
=
0
;
...
@@ -264,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
...
@@ -264,7 +291,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
}
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
f9c478e2
#ifndef REFERENCE_CONV_FWD_HPP
#pragma once
#define REFERENCE_CONV_FWD_HPP
#include <iostream>
#include <iostream>
#include <type_traits>
#include <type_traits>
#include <sstream>
#include <sstream>
#include "stream_config.hpp"
#include "device_base.hpp"
#include "device_base.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
...
@@ -88,13 +89,16 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -88,13 +89,16 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_ncw
=
[
&
](
auto
n
,
auto
k
,
auto
wo
)
{
auto
f_ncw
=
[
&
](
auto
n
,
auto
k
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_
t
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
in
t
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
x
)
for
(
std
::
size_
t
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
x
)
{
{
int
wi
=
wo
*
arg
.
conv_strides_
[
0
]
+
x
*
arg
.
conv_dilations_
[
0
]
-
auto
wi
=
arg
.
in_left_pads_
[
0
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
if
(
wi
>=
0
&&
wi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
...
@@ -128,18 +132,26 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -128,18 +132,26 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_
t
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
in
t
y
=
0
;
y
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
std
::
size_
t
y
=
0
;
y
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
auto
hi
=
arg
.
in_left_pads_
[
0
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
for
(
int
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
auto
wi
=
arg
.
in_left_pads_
[
1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
if
(
hi
>=
0
&&
hi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
wi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
...
@@ -174,23 +186,37 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -174,23 +186,37 @@ struct ReferenceConvFwd : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_
t
c
=
0
;
c
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
in
t
z
=
0
;
z
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
z
)
for
(
std
::
size_
t
z
=
0
;
z
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
];
++
z
)
{
{
int
di
=
d_o
*
arg
.
conv_strides_
[
0
]
+
z
*
arg
.
conv_dilations_
[
0
]
-
auto
di
=
arg
.
in_left_pads_
[
0
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
for
(
int
y
=
0
;
y
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
++
y
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
y
=
0
;
y
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
];
++
y
)
{
{
int
hi
=
ho
*
arg
.
conv_strides_
[
1
]
+
y
*
arg
.
conv_dilations_
[
1
]
-
auto
hi
=
arg
.
in_left_pads_
[
1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
1
])
+
for
(
int
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
];
++
x
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
];
++
x
)
{
{
int
wi
=
wo
*
arg
.
conv_strides_
[
2
]
+
auto
wi
=
x
*
arg
.
conv_dilations_
[
2
]
-
arg
.
in_left_pads_
[
2
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
if
(
di
>=
0
&&
di
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
arg
.
conv_strides_
[
2
])
+
hi
>=
0
&&
hi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
&&
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
wi
>=
0
&&
wi
<
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
arg
.
conv_dilations_
[
2
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
...
@@ -226,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -226,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator
}
}
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
@@ -286,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -286,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp
View file @
f9c478e2
...
@@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
...
@@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
c
=
0
;
c
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_
t
c
=
0
;
c
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
in
t
y
=
0
;
y
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
std
::
size_
t
y
=
0
;
y
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
arg
.
in_left_pads_
[
0
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
for
(
int
x
=
0
;
x
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
auto
wi
=
arg
.
in_left_pads_
[
1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
if
(
hi
>=
0
&&
hi
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
wi
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
...
@@ -117,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
...
@@ -117,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp
View file @
f9c478e2
...
@@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
...
@@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f_nchw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
in
t
c
=
0
;
c
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_
t
c
=
0
;
c
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
];
++
c
)
{
{
for
(
in
t
y
=
0
;
y
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
std
::
size_
t
y
=
0
;
y
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
*
arg
.
conv_strides_
[
0
]
+
y
*
arg
.
conv_dilations_
[
0
]
-
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
arg
.
in_left_pads_
[
0
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
for
(
int
x
=
0
;
x
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
*
arg
.
conv_strides_
[
1
]
+
x
*
arg
.
conv_dilations_
[
1
]
-
auto
wi
=
arg
.
in_left_pads_
[
1
];
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
if
(
hi
>=
0
&&
hi
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
1
])
-
wi
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
...
@@ -123,7 +130,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
...
@@ -123,7 +130,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/*stream_config*/
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
f9c478e2
#ifndef REFERENCE_GEMM_HPP
#pragma once
#define REFERENCE_GEMM_HPP
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -13,6 +11,7 @@ namespace host {
...
@@ -13,6 +11,7 @@ namespace host {
template
<
typename
ADataType
,
template
<
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
...
@@ -55,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -55,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
const
int
K
=
arg
.
a_m_k_
.
mDesc
.
GetLengths
()[
1
];
float
v_acc
=
0
;
AccDataType
v_acc
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
{
float
v_a
;
AccDataType
v_a
;
float
v_b
;
AccDataType
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
a_element_op_
(
v_a
,
static_cast
<
const
AccDataType
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
arg
.
b_element_op_
(
v_b
,
static_cast
<
const
AccDataType
>
(
arg
.
b_k_n_
(
k
,
n
)));
v_acc
+=
v_a
*
v_b
;
v_acc
+=
v_a
*
v_b
;
}
}
float
v_c
;
AccDataType
v_c
;
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
...
@@ -82,7 +81,8 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -82,7 +81,8 @@ struct ReferenceGemm : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
@@ -129,4 +129,3 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -129,4 +129,3 @@ struct ReferenceGemm : public device::BaseOperator
}
// namespace host
}
// namespace host
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
#endif
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp
View file @
f9c478e2
...
@@ -82,7 +82,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
...
@@ -82,7 +82,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp
View file @
f9c478e2
...
@@ -85,7 +85,8 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
...
@@ -85,7 +85,8 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp
View file @
f9c478e2
...
@@ -91,7 +91,8 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
...
@@ -91,7 +91,8 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator
return
0
;
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
int
)
override
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp
View file @
f9c478e2
...
@@ -9,26 +9,11 @@
...
@@ -9,26 +9,11 @@
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
...
...
library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp
View file @
f9c478e2
...
@@ -3,13 +3,27 @@
...
@@ -3,13 +3,27 @@
#include "reduction_operator_mapping.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_block
wise
.hpp"
#include "device_reduce_
multi
block.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_reduce_instance
{
namespace
device_reduce_instance
{
using
reduce_configuration_1_instances_blockwise
=
std
::
tuple
<
// clang-format off
// BlockSize | MThreadClusterSize | KThreadClusterSize
ReductionConfiguration_1
<
256
,
128
,
2
>
,
ReductionConfiguration_1
<
256
,
64
,
4
>
,
ReductionConfiguration_1
<
256
,
32
,
8
>
,
ReductionConfiguration_1
<
256
,
16
,
16
>
,
ReductionConfiguration_1
<
256
,
8
,
32
>
,
ReductionConfiguration_1
<
256
,
4
,
64
>
,
ReductionConfiguration_1
<
256
,
2
,
128
>
,
ReductionConfiguration_1
<
256
,
1
,
256
>
// clang-format on
>
;
#ifdef QUICK_REDUCE_TEST
#ifdef QUICK_REDUCE_TEST
using
reduce_configuration_2_instances_blockwise
=
std
::
tuple
<
using
reduce_configuration_2_instances_blockwise
=
std
::
tuple
<
// clang-format off
// clang-format off
...
@@ -58,8 +72,8 @@ template <typename InDataType,
...
@@ -58,8 +72,8 @@ template <typename InDataType,
int
Rank
,
int
Rank
,
int
NumReduceDim
,
int
NumReduceDim
,
ReduceTensorOp
ReduceOpId
,
ReduceTensorOp
ReduceOpId
,
Nan
Propagat
ion
NanOpt
,
bool
Propagat
eNan
,
ReduceTensorIndices
IndicesOpt
>
bool
UseIndex
>
void
add_device_reduce_instance_blockwise
(
void
add_device_reduce_instance_blockwise
(
std
::
vector
<
deviceReduceBlockWisePtrType
<
AccDataType
,
ReduceOpId
>>&
device_op_instances
)
std
::
vector
<
deviceReduceBlockWisePtrType
<
AccDataType
,
ReduceOpId
>>&
device_op_instances
)
{
{
...
@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
...
@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
constexpr
bool
Indexable
=
constexpr
bool
Indexable
=
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
(
ReduceOpId
==
ReduceTensorOp
::
MIN
||
ReduceOpId
==
ReduceTensorOp
::
MAX
||
ReduceOpId
==
ReduceTensorOp
::
AMAX
);
ReduceOpId
==
ReduceTensorOp
::
AMAX
);
constexpr
bool
NeedIndices
=
Indexable
&&
(
IndicesOpt
!=
ReduceTensorIndices
::
NO_INDICES
);
constexpr
bool
OutputIndex
=
Indexable
&&
UseIndex
;
constexpr
bool
PropagateNan
=
(
NanOpt
==
NanPropagation
::
NOT_PROPAGATE_NAN
)
?
false
:
true
;
static_for
<
0
,
std
::
tuple_size
<
reduce_configuration_1_instances_blockwise
>::
value
,
1
>
{}(
[
&
](
auto
i
)
{
static_for
<
0
,
std
::
tuple_size
<
reduce_configuration_1_instances
>::
value
,
1
>
{}([
&
](
auto
i
)
{
using
cfg1
=
remove_cvref_t
<
decltype
(
using
cfg1
=
std
::
get
<
i
.
value
>
(
reduce_configuration_1_instances_blockwise
{}))
>
;
remove_cvref_t
<
decltype
(
std
::
get
<
i
.
value
>
(
reduce_configuration_1_instances
{}))
>
;
static_for
<
0
,
std
::
tuple_size
<
reduce_configuration_2_instances_blockwise
>::
value
,
1
>
{}(
static_for
<
0
,
std
::
tuple_size
<
reduce_configuration_2_instances_blockwise
>::
value
,
1
>
{}(
[
&
](
auto
j
)
{
[
&
](
auto
j
)
{
using
cfg2
=
remove_cvref_t
<
decltype
(
using
cfg2
=
remove_cvref_t
<
decltype
(
std
::
get
<
j
.
value
>
(
reduce_configuration_2_instances_blockwise
{}))
>
;
std
::
get
<
j
.
value
>
(
reduce_configuration_2_instances_blockwise
{}))
>
;
using
ReduceOpInstance
=
using
ReduceOpInstance
=
DeviceReduceBlockWise
<
InDataType
,
DeviceReduceMultiBlock
<
InDataType
,
AccDataType
,
AccDataType
,
OutDataType
,
OutDataType
,
Rank
,
Rank
,
NumReduceDim
,
NumReduceDim
,
ReduceOperation
,
ReduceOperation
,
InElementwiseOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
PropagateNan
,
InMemoryDataOperationEnum
::
Set
,
NeedIndices
,
PropagateNan
,
cfg1
::
BlockSize_
,
OutputIndex
,
cfg1
::
MThreadClusterSize_
,
false
,
// HaveIndexInputIfOutputIndex
cfg1
::
KThreadClusterSize_
,
cfg1
::
BlockSize_
,
cfg2
::
MThreadSliceSize_
,
cfg1
::
MThreadClusterSize_
,
cfg2
::
KThreadSliceSize_
,
cfg1
::
KThreadClusterSize_
,
cfg2
::
InSrcVectorDim_
,
cfg2
::
MThreadSliceSize_
,
cfg2
::
InSrcVectorSize_
,
cfg2
::
KThreadSliceSize_
,
cfg2
::
OutDstVectorSize_
>
;
cfg2
::
InSrcVectorDim_
,
cfg2
::
InSrcVectorSize_
,
device_op_instances
.
push_back
(
cfg2
::
OutDstVectorSize_
>
;
std
::
make_unique
<
ReduceOpInstance
>
(
ReduceOpInstance
{}));
});
device_op_instances
.
push_back
(
});
std
::
make_unique
<
ReduceOpInstance
>
(
ReduceOpInstance
{}));
});
});
};
};
#define ADD_BLOCKWISE_INST_BY_TYPE( \
#define ADD_BLOCKWISE_INST_BY_TYPE(
\
inT, compT, outT, ReduceOpId,
NanOpt, IndicesOpt
, Rank, NumReduceDim) \
inT, compT, outT, ReduceOpId,
PropagateNan, UseIndex
, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \
template void add_device_reduce_instance_blockwise<inT,
\
compT, \
compT,
\
outT, \
outT,
\
Rank, \
Rank,
\
NumReduceDim, \
NumReduceDim,
\
ReduceOpId, \
ReduceOpId,
\
NanOpt,
\
PropagateNan,
\
IndicesOpt>(
\
UseIndex>(
\
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID(
\
#define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)
\
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT,
\
ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT,
\
compT, \
outT,
\
outT, \
static_cast<ReduceTensorOp>(ReduceOpId),
\
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<
NanPropagation
>(NanOpt), \
static_cast<
bool
>(NanOpt),
\
static_cast<
ReduceTensorIndices
>(IndicesOpt), \
static_cast<
bool
>(IndicesOpt),
\
Rank,
\
Rank, \
NumReduceDim)
NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId,
NanOpt, IndicesOpt
, Rank, NumReduceDim)
\
inT, compT, outT, ReduceOpId,
PropagateNan, UseIndex
, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \
extern template void add_device_reduce_instance_blockwise<inT, \
compT, \
compT, \
outT, \
outT, \
Rank, \
Rank, \
NumReduceDim, \
NumReduceDim, \
ReduceOpId, \
ReduceOpId, \
NanOpt,
\
PropagateNan,
\
IndicesOpt>(
\
UseIndex>(
\
std::vector<DeviceReducePtr< \
std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \
AccElementwiseOperation>> & \
device_op_instances)
device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID(
\
#define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)
\
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT,
\
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT,
\
compT, \
outT,
\
outT, \
static_cast<ReduceTensorOp>(ReduceOpId),
\
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<
NanPropagation
>(NanOpt), \
static_cast<
bool
>(NanOpt),
\
static_cast<
ReduceTensorIndices
>(IndicesOpt), \
static_cast<
bool
>(IndicesOpt),
\
Rank,
\
Rank, \
NumReduceDim)
NumReduceDim)
}
// namespace device_reduce_instance
}
// namespace device_reduce_instance
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment