gaoqiong / composable_kernel · Commits

Commit 7a3b49e5
Authored Jun 25, 2022 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into contraction

Parents: e07b3d8e, d3051d75

Changes: 592 files in the full merge; showing 20 changed files with 207 additions and 2305 deletions (+207 −2305).
include/ck/utility/tuple_helper.hpp  +14 −0
include/ck/utility/type.hpp  +5 −4
library/include/ck/library/host/host_interface.hpp  +0 −54
library/include/ck/library/host_tensor/conv_common.hpp  +5 −18
library/include/ck/library/host_tensor/device_memory.hpp  +40 −0
library/include/ck/library/host_tensor/device_tensor.hpp  +0 −8
library/include/ck/library/host_tensor/host_common_util.hpp  +7 −33
library/include/ck/library/host_tensor/host_conv.hpp  +3 −0
library/include/ck/library/host_tensor/host_gemm.hpp  +4 −0
library/include/ck/library/host_tensor/host_reduce_util.hpp  +0 −257
library/include/ck/library/host_tensor/host_reduction.hpp  +51 −77
library/include/ck/library/host_tensor/host_tensor.hpp  +72 −7
library/include/ck/library/host_tensor/host_tensor_generator.hpp  +6 −3
library/include/ck/library/obselete_driver_offline/debug.hpp  +0 −13
library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp  +0 −220
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp  +0 −309
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp  +0 −423
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp  +0 −389
library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp  +0 −256
library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp  +0 −234
include/ck/utility/tuple_helper.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 #include "functional4.hpp"
 ...
@@ -19,6 +22,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
                         typename arithmetic_sequence_gen<0, N, 1>::type{});
 }

+// tx and ty are tuples of references; the return type will also be a tuple of references (not rvalues)
+template <typename... X, typename... Y>
+__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
+                                                             const Tuple<Y&...>& ty)
+{
+    return unpack2(
+        [&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; },
+        tx,
+        ty);
+}
+
 namespace detail {

 template <typename F, typename X, index_t... Is>
 ...
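The subtlety in concat_tuple_of_reference is that the result type is spelled Tuple<decltype(zs)...>, so every element keeps its reference type instead of decaying to a value; concatenation therefore never copies the referenced objects. A minimal standalone sketch of the same idea, using std::tuple and std::tuple_cat in place of ck::Tuple and unpack2 (which are not reproduced here):

#include <cassert>
#include <tuple>

int main()
{
    int a = 1, b = 2, c = 3;

    std::tuple<int&, int&> tx{a, b}; // tuple of references
    std::tuple<int&> ty{c};

    // Concatenation preserves the reference-ness of each element.
    auto tz = std::tuple_cat(tx, ty); // std::tuple<int&, int&, int&>

    std::get<0>(tz) = 42; // writes through to 'a'
    assert(a == 42);
}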
include/ck/utility/type.hpp

-#ifndef CK_TYPE_HPP
-#define CK_TYPE_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "config.hpp"
+#include "ck/ck.hpp"
 #include "integral_constant.hpp"
 #include "enable_if.hpp"
 ...
@@ -56,4 +58,3 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
 }
 } // namespace ck
-#endif
library/include/ck/library/host/host_interface.hpp
deleted 100644 → 0

#pragma once
#include <memory>
#include <string>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"

struct DeviceConvFwdPtr_t
{
    using BaseArgument = ck::tensor_operation::device::BaseArgument;
    using BaseInvoker  = ck::tensor_operation::device::BaseInvoker;
    struct DeviceConvFwdPtrImpl;
    std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;
    DeviceConvFwdPtr_t();
    ~DeviceConvFwdPtr_t();
    DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
    DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
    DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&)       = delete;
    DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(void* in_ptr,
                        void* wei_ptr,
                        void* out_ptr,
                        size_t N,
                        size_t K,
                        size_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads) const;
    // in, wei and out element ops are ignored for now since even if we change them, they
    // can't be linked
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() const; // requires including BaseInvoker headers
    std::string GetTypeString();
    bool IsSupportedArgument(const BaseArgument* arg_ptr);
};

void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
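For context, the removed interface follows a registry/pImpl pattern: an add_..._instances_t function fills a vector with type-erased conv instances, and the caller probes each one with IsSupportedArgument before creating an invoker. A hypothetical usage sketch (the problem sizes, the buffer pointers, and the trailing invoker comment are illustrative placeholders, not taken from this diff):

// Hypothetical caller of the (now removed) host interface.
void* in_ptr  = nullptr; // device buffers, allocated elsewhere
void* wei_ptr = nullptr;
void* out_ptr = nullptr;

std::vector<DeviceConvFwdPtr_t> instances;
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(instances);

for(auto& inst : instances)
{
    auto arg = inst.MakeArgumentPointer(in_ptr, wei_ptr, out_ptr,
                                        /*N=*/128, /*K=*/256, /*C=*/192,
                                        {71, 71}, // input spatial lengths
                                        {3, 3},   // filter spatial lengths
                                        {36, 36}, // output spatial lengths
                                        {2, 2},   // conv strides
                                        {1, 1},   // dilations
                                        {1, 1},   // left pads
                                        {1, 1});  // right pads
    if(inst.IsSupportedArgument(arg.get()))
    {
        auto invoker = inst.MakeInvokerPointer();
        // invoker->Run(...) — BaseInvoker's exact run signature is not shown in this diff
    }
}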
library/include/ck/library/host_tensor/conv_common.hpp

-#ifndef CONV_COMMON_HPP
-#define CONV_COMMON_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

-#include "tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"

 template <typename... InDesc,
           typename... WeiDesc,
 ...
@@ -73,18 +75,3 @@ calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDes
     return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
 }
-
-template <typename T>
-inline auto activ(T v, const ck::ActivTypeEnum activ_type)
-{
-    const T alpha = 0.3;
-
-    switch(activ_type)
-    {
-    case ck::ActivTypeEnum::None: return v;
-    case ck::ActivTypeEnum::LeakyRelu: return (v >= 0 ? v : alpha * v);
-    case ck::ActivTypeEnum::Sigmoid: return (1 / (1 + exp(-v)));
-    default: throw std::runtime_error("unsupported activ type"); break;
-    }
-}
-#endif
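The retained calculate_convolution_flops body encodes the standard forward-convolution operation count: each of the N·K·Ho·Wo outputs accumulates over C·Y·X input/weight products, and one multiply plus one add per product gives the factor of two:

\text{FLOPs} = 2 \cdot N \cdot K \cdot H_o \cdot W_o \cdot C \cdot Y \cdot X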
library/include/ck/library/host_tensor/device_memory.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>

template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
    for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
    {
        p[i] = x;
    }
}

struct DeviceMem
{
    DeviceMem() = delete;
    DeviceMem(std::size_t mem_size);
    void* GetDeviceBuffer();
    std::size_t GetBufferSize();
    void ToDevice(const void* p);
    void FromDevice(void* p);
    void SetZero();
    template <typename T>
    void SetValue(T x)
    {
        if(mMemSize % sizeof(T) != 0)
        {
            throw std::runtime_error("wrong! not entire DeviceMem will be set");
        }

        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
    }
    ~DeviceMem();

    void* mpDeviceBuf;
    std::size_t mMemSize;
};
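A short usage sketch of the new DeviceMem wrapper (host-side HIP code; the buffer length is illustrative). Note that SetValue launches set_buffer_value with a single block of 1024 threads whose stride loop covers buffers of any length, and that it throws if the byte size is not a multiple of sizeof(T):

#include <vector>

void fill_example()
{
    std::vector<float> host(4096, 0.f);

    DeviceMem buf(sizeof(float) * host.size()); // byte size; default construction is deleted
    buf.ToDevice(host.data());                  // host -> device copy
    buf.SetValue(3.5f);                         // set_buffer_value<float><<<1, 1024>>> over 4096 elements
    buf.FromDevice(host.data());                // device -> host; every host[i] is now 3.5f
}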
library/include/ck/library/host_tensor/device_tensor.hpp
deleted 100644 → 0

#pragma once
#include "host_tensor.hpp"

template <typename TensorDesc>
void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
{
    ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os);
}
library/include/ck/library/host_tensor/host_common_util.hpp
View file @
7a3b49e5
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef GUARD_HOST_COMMON_UTIL_HPP
#define GUARD_HOST_COMMON_UTIL_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <iostream>
#include <fstream>
#include <string>
#include "c
onfig
.hpp"
#include "c
k/ck
.hpp"
namespace
ck
{
...
...
@@ -95,8 +72,5 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
return
(
values
);
}
};
// namespace host_common
};
// namespace ck
#endif
}
// namespace host_common
}
// namespace ck
library/include/ck/library/host_tensor/host_conv.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 #include "host_tensor.hpp"
 #include "conv_common.hpp"
 ...
library/include/ck/library/host_tensor/host_gemm.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 #include "host_tensor.hpp"

 template <typename AType,
 ...
library/include/ck/library/host_tensor/host_reduce_util.hpp
deleted 100644 → 0

/*******************************************************************************
 *
 * MIT License
 *
 * Copyright (c) 2020 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 *******************************************************************************/
#ifndef GUARD_HOST_REDUCE_UTIL_HPP
#define GUARD_HOST_REDUCE_UTIL_HPP

#include <limits>
#include <cmath>
#include <functional>

#include "reduction_enums.hpp"
#include "data_type.hpp"
#include "math_v2.hpp"

namespace ck {
namespace host_reduce {

using ck::NanPropagation;
using ck::ReduceTensorOp;

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PreUnaryOpFn(int)
{
    using ck::math::abs;

    if constexpr(ReduceOpId == ReduceTensorOp::NORM1)
    {
        return ([&](AccDataType& a_) { a_ = abs(a_); });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_) { a_ = a_ * a_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_) { a_ = abs(a_); });
    }
    else
    {
        // ReduceTensorOp::AVG:
        // ReduceTensorOp::ADD:
        // ReduceTensorOp::MUL:
        // ReduceTensorOp::MIN:
        // ReduceTensorOp::MAX:
        return ([&](AccDataType&) {});
    };
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&)> PosUnaryOpFn(int32_t divider)
{
    using std::sqrt;

    if constexpr(ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_) { a_ = sqrt(a_); });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AVG)
    {
        return ([&, divider](AccDataType& a_) {
            a_ = a_ / static_cast<AccDataType>(static_cast<float>(divider));
        });
    }
    else
    {
        // ReduceTensorOp::ADD:
        // ReduceTensorOp::NORM1:
        // ReduceTensorOp::MUL:
        // ReduceTensorOp::MIN:
        // ReduceTensorOp::MAX:
        // ReduceTensorOp::AMAX:
        return ([&](AccDataType&) {});
    }
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType)> ReduceOpFn()
{
    if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG ||
                 ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2)
    {
        return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MUL)
    {
        return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return ([&](AccDataType& a_, AccDataType b_) {
            if(a_ > b_)
                a_ = b_;
        });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_, AccDataType b_) {
            if(a_ < b_)
                a_ = b_;
        });
    }
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline std::function<void(AccDataType&, AccDataType, bool& changed)> ReduceOpFn2()
{
    if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
            if(a_ > b_)
            {
                a_      = b_;
                changed = true;
            }
            else
                changed = false;
        });
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX)
    {
        return ([&](AccDataType& a_, AccDataType b_, bool& changed) {
            if(a_ < b_)
            {
                a_      = b_;
                changed = true;
            }
            else
                changed = false;
        });
    }
    else
    {
        // ReduceTensorOp::ADD:
        // ReduceTensorOp::MUL:
        // ReduceTensorOp::AVG:
        // ReduceTensorOp::NORM1:
        // ReduceTensorOp::NORM2:
        return (std::function<void(AccDataType&, AccDataType, bool&)>{});
    };
};

template <typename AccDataType, ReduceTensorOp ReduceOpId>
__host__ static inline AccDataType ReduceOpZeroVal()
{
    if constexpr(ReduceOpId == ReduceTensorOp::MUL)
    {
        return (static_cast<AccDataType>(1.0f));
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MIN)
    {
        return (ck::NumericLimits<AccDataType>::Max());
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::MAX)
    {
        return (ck::NumericLimits<AccDataType>::Lowest());
    }
    else if constexpr(ReduceOpId == ReduceTensorOp::AMAX)
    {
        return (static_cast<AccDataType>(0.0f));
    }
    else
    {
        // ReduceTensorOp::ADD
        // ReduceTensorOp::AVG
        // ReduceTensorOp::NORM1
        // ReduceTensorOp::NORM2
        return (static_cast<AccDataType>(0.0f));
    };
};

template <typename AccDataType, bool PropagateNan>
__host__ static inline void
binop_with_nan_check(std::function<void(AccDataType&, AccDataType)> opReduce,
                     AccDataType& accuVal,
                     AccDataType currVal)
{
    using ck::math::isnan;

    if constexpr(!PropagateNan)
    {
        opReduce(accuVal, currVal);
    }
    else
    {
        if(isnan(currVal))
            accuVal = currVal;
        else
            opReduce(accuVal, currVal);
    };
};

template <typename AccDataType, typename IndexDataType, bool PropagateNan>
__host__ static inline void
binop_with_index_and_nan_check(std::function<void(AccDataType&, AccDataType, bool&)> opReduce,
                               AccDataType& accuVal,
                               AccDataType currVal,
                               IndexDataType& accuIndex,
                               IndexDataType currIndex)
{
    using ck::math::isnan;

    if constexpr(!PropagateNan)
    {
        bool changed;

        opReduce(accuVal, currVal, changed);

        if(changed)
            accuIndex = currIndex;
    }
    else
    {
        if(isnan(currVal))
        {
            accuVal   = currVal;
            accuIndex = currIndex;
        }
        else
        {
            bool changed;

            opReduce(accuVal, currVal, changed);

            if(changed)
                accuIndex = currIndex;
        };
    };
};

}; // namespace host_reduce
}; // namespace ck
#endif
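The deleted helpers factor a host reduction into three functors: a per-element pre-op (PreUnaryOpFn), a binary accumulate (ReduceOpFn), and a final post-op (PosUnaryOpFn). For NORM2 that composition is square → sum → sqrt. A self-contained sketch of the same pipeline using plain std::function, mirroring the deleted code without its ck dependencies:

#include <cmath>
#include <functional>
#include <iostream>
#include <vector>

int main()
{
    // NORM2 as the deleted helpers composed it: the pre-op squares each element,
    // the reduce op accumulates by addition, and the post-op takes the square root.
    std::function<void(float&)> pre        = [](float& a) { a = a * a; };
    std::function<void(float&, float)> op  = [](float& a, float b) { a = a + b; };
    std::function<void(float&)> post       = [](float& a) { a = std::sqrt(a); };

    std::vector<float> data{3.f, 4.f};
    float acc = 0.f; // identity value for ADD/NORM1/NORM2

    for(float v : data)
    {
        pre(v);     // v -> v*v
        op(acc, v); // acc += v*v
    }
    post(acc); // acc -> sqrt(acc)

    std::cout << acc << '\n'; // prints 5, the L2 norm of {3, 4}
}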
library/include/ck/library/host_tensor/host_reduction.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
-/*******************************************************************************
- *
- * MIT License
- *
- * Copyright (c) 2020 Advanced Micro Devices, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- *
- *******************************************************************************/
-#ifndef HOST_REDUCTION_HPP_
-#define HOST_REDUCTION_HPP_
+#pragma once

 #include <vector>
 #include <array>
 #include <functional>

-#include "reduction_enums.hpp"
-#include "reduction_common.hpp"
-#include "host_reduce_util.hpp"
-#include "host_common_util.hpp"
-#include "host_tensor.hpp"
-#include "data_type.hpp"
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/reduction_enums.hpp"
+#include "ck/utility/reduction_common.hpp"
+#include "ck/utility/reduction_functions_accumulate.hpp"
+#include "ck/library/host_tensor/host_common_util.hpp"
+#include "ck/library/host_tensor/host_tensor.hpp"

 template <int NDim>
 static void get_all_indexes(const std::array<size_t, NDim>& dimLengths,
 ...
@@ -106,11 +82,13 @@ static size_t get_offset_from_index(const std::vector<size_t>& strides,
 template <typename InDataType,
           typename AccDataType,
           typename OutDataType,
           ck::ReduceTensorOp ReduceOpId,
+          typename ReduceOperation,
+          typename InElementwiseOperation,
+          typename AccElementwiseOperation,
           int Rank,
           int NumReduceDim,
           bool PropagateNan,
-          bool NeedIndices>
+          bool OutputIndex>
 struct ReductionHost
 {
     using IndexDataType = int32_t;
 ...
@@ -122,8 +100,6 @@ struct ReductionHost
     std::vector<int> reduceDims;
     IndexDataType divider;
-    std::function<void(AccDataType&)> preUnaryOp;
-    std::function<void(AccDataType&)> posUnaryOp;
     std::array<size_t, NumReduceDim> reduceLengths;
     std::array<size_t, NumReduceDim> reduceStrides;
     std::array<size_t, NumInvariantDim> invariantLengths;
 ...
@@ -137,9 +113,6 @@ struct ReductionHost
                   const std::vector<int>& invariantDims_,
                   const std::vector<int>& reduceDims_)
     {
-        using ck::host_reduce::PosUnaryOpFn;
-        using ck::host_reduce::PreUnaryOpFn;
-
-        // this->outLengths = to_int_vector(outDesc.GetLengths());
         this->outStrides = outDesc.GetStrides();
 ...
@@ -171,24 +144,24 @@ struct ReductionHost
             invariant_dim_indexes.clear();

            get_all_indexes<NumInvariantDim>(invariantLengths, invariant_dim_indexes);
        };

-        preUnaryOp = PreUnaryOpFn<AccDataType, ReduceOpId>(divider);
-        posUnaryOp = PosUnaryOpFn<AccDataType, ReduceOpId>(divider);
    };

    void Run(float alpha,
             const InDataType* in_data,
             float beta,
             OutDataType* out_data,
-             IndexDataType* out_indices)
+             IndexDataType* out_indices,
+             InElementwiseOperation in_elementwise_op,
+             AccElementwiseOperation acc_elementwise_op)
    {
-        if constexpr(NeedIndices)
+        if constexpr(OutputIndex)
        {
-            RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
+            RunImpl_with_index(
+                alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
        }
        else
        {
-            RunImpl_no_index(alpha, in_data, beta, out_data);
+            RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
        };
    };
 ...
@@ -196,20 +169,22 @@ struct ReductionHost
                            const InDataType* in_data,
                            float beta,
                            OutDataType* out_data,
-                            IndexDataType* out_indices)
+                            IndexDataType* out_indices,
+                            InElementwiseOperation in_elementwise_op,
+                            AccElementwiseOperation acc_elementwise_op)
    {
        using ck::float_equal_one;
        using ck::float_equal_zero;
        using ck::type_convert;
-        using ck::host_reduce::binop_with_index_and_nan_check;
-        using ck::host_reduce::ReduceOpFn2;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce2 = ReduceOpFn2<AccDataType, ReduceOpId>();
+        using Accumulation = ck::detail::
+            AccumulateWithIndexAndNanCheck<PropagateNan, ReduceOperation, AccDataType, IndexDataType>;

        if constexpr(NumInvariantDim == 0)
        {
-            AccDataType accuVal     = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal     = ReduceOperation::template GetIdentityValue<AccDataType>();
            IndexDataType accuIndex = 0;

            for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
 ...
@@ -219,15 +194,14 @@ struct ReductionHost
                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

                auto currIndex = static_cast<IndexDataType>(i);
-                binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                    opReduce2, accuVal, currVal, accuIndex, currIndex);
+                Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
            };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);
 ...
@@ -241,7 +215,7 @@ struct ReductionHost
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal     = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal     = ReduceOperation::template GetIdentityValue<AccDataType>();
                IndexDataType accuIndex = 0;

                auto offset_invariant =
 ...
@@ -255,15 +229,14 @@ struct ReductionHost
                    auto currVal =
                        type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                    preUnaryOp(currVal);
+                    in_elementwise_op(currVal, currVal);

                    auto currIndex = static_cast<IndexDataType>(i);
-                    binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
-                        opReduce2, accuVal, currVal, accuIndex, currIndex);
+                    Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex);
                };

-                posUnaryOp(accuVal);
+                acc_elementwise_op(accuVal, accuVal);

                if(!float_equal_one{}(alpha))
                    accuVal *= type_convert<AccDataType>(alpha);
 ...
@@ -303,20 +276,23 @@ struct ReductionHost
        };
    };

-    void RunImpl_no_index(float alpha,
-                          const InDataType* in_data,
-                          float beta,
-                          OutDataType* out_data)
+    void RunImpl_no_index(float alpha,
+                          const InDataType* in_data,
+                          float beta,
+                          OutDataType* out_data,
+                          InElementwiseOperation in_elementwise_op,
+                          AccElementwiseOperation acc_elementwise_op)
    {
        using ck::float_equal_one;
        using ck::float_equal_zero;
        using ck::type_convert;
-        using ck::host_reduce::binop_with_nan_check;
-        using ck::host_reduce::ReduceOpFn;
-        using ck::host_reduce::ReduceOpZeroVal;

-        auto opReduce = ReduceOpFn<AccDataType, ReduceOpId>();
+        using Accumulation =
+            ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;

        if constexpr(NumInvariantDim == 0)
        {
-            AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+            AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();

            for(const auto& reduce_index : reduce_dim_indexes)
            {
 ...
@@ -325,12 +301,12 @@ struct ReductionHost
                auto currVal = type_convert<AccDataType>(in_data[offset_reduce]);

-                preUnaryOp(currVal);
+                in_elementwise_op(currVal, currVal);

-                binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                Accumulation::Calculate(accuVal, currVal);
            };

-            posUnaryOp(accuVal);
+            acc_elementwise_op(accuVal, accuVal);

            if(!float_equal_one{}(alpha))
                accuVal *= type_convert<AccDataType>(alpha);
 ...
@@ -343,7 +319,7 @@ struct ReductionHost
        else
        {
            auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
+                AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();

                auto offset_invariant =
                    get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
 ...
@@ -356,12 +332,12 @@ struct ReductionHost
                    auto currVal =
                        type_convert<AccDataType>(in_data[offset_invariant + offset_reduce]);

-                    preUnaryOp(currVal);
+                    in_elementwise_op(currVal, currVal);

-                    binop_with_nan_check<AccDataType, PropagateNan>(opReduce, accuVal, currVal);
+                    Accumulation::Calculate(accuVal, currVal);
                };

-                posUnaryOp(accuVal);
+                acc_elementwise_op(accuVal, accuVal);

                if(!float_equal_one{}(alpha))
                    accuVal *= type_convert<AccDataType>(alpha);
 ...
@@ -400,5 +376,3 @@ struct ReductionHost
    };
 };
-
-#endif
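The refactor replaces the enum-dispatched host_reduce helpers with policy types: the reduction operation supplies its own identity via GetIdentityValue<Acc>(), and a NaN-aware accumulator (AccumulateWithNanCheck / AccumulateWithIndexAndNanCheck) wraps the binary op. A self-contained sketch of that pattern; the struct names mirror the diff, but the bodies here are illustrative stand-ins, not CK's actual reduction_functions_accumulate.hpp:

#include <cmath>
#include <iostream>

// Illustrative stand-ins for the policy types the diff switches to.
struct Add
{
    template <typename T>
    static T GetIdentityValue() { return T{0}; } // identity element of the op

    template <typename T>
    static void Reduce(T& a, T b) { a = a + b; }
};

template <bool PropagateNan, typename Op, typename Acc>
struct AccumulateWithNanCheck
{
    static void Calculate(Acc& accu, Acc curr)
    {
        if constexpr(PropagateNan)
        {
            if(std::isnan(curr))
            {
                accu = curr; // a NaN input poisons the accumulator
                return;
            }
        }
        Op::Reduce(accu, curr);
    }
};

int main()
{
    using Accumulation = AccumulateWithNanCheck<true, Add, float>;

    float acc = Add::GetIdentityValue<float>();
    for(float v : {1.f, 2.f, 3.f})
        Accumulation::Calculate(acc, v);

    std::cout << acc << '\n'; // 6
}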
library/include/ck/library/host_tensor/host_tensor.hpp

-#ifndef HOST_TENSOR_HPP
-#define HOST_TENSOR_HPP
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once

 #include <thread>
 #include <vector>
 ...
@@ -8,7 +10,8 @@
 #include <utility>
 #include <cassert>
 #include <iostream>

-#include "data_type.hpp"
+#include "ck/utility/data_type.hpp"

 template <typename Range>
 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
 ...
@@ -107,6 +110,11 @@ struct HostTensorDescriptor
         return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
     }

+    std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
+    {
+        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
+    }
+
     friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);

     private:
 ...
@@ -212,6 +220,54 @@ struct Tensor
     Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {}

     Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {}

+    template <typename F>
+    void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
+    {
+        if(rank == mDesc.GetNumOfDimension())
+        {
+            f(*this, idx);
+            return;
+        }
+        // else
+        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
+        {
+            idx[rank] = i;
+            ForEach_impl(std::forward<F>(f), idx, rank + 1);
+        }
+    }
+
+    template <typename F>
+    void ForEach(F&& f)
+    {
+        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
+        ForEach_impl(std::forward<F>(f), idx, size_t(0));
+    }
+
+    template <typename F>
+    void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
+    {
+        if(rank == mDesc.GetNumOfDimension())
+        {
+            f(*this, idx);
+            return;
+        }
+        // else
+        for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
+        {
+            idx[rank] = i;
+            ForEach_impl(std::forward<const F>(f), idx, rank + 1);
+        }
+    }
+
+    template <typename F>
+    void ForEach(const F&& f) const
+    {
+        std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
+        ForEach_impl(std::forward<const F>(f), idx, size_t(0));
+    }
+
     template <typename G>
     void GenerateTensorValue(G g, std::size_t num_thread = 1)
     {
 ...
@@ -272,6 +328,16 @@ struct Tensor
         return mData[mDesc.GetOffsetFromMultiIndex(is...)];
     }

+    T& operator()(std::vector<std::size_t> idx)
+    {
+        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+    }
+
+    const T& operator()(std::vector<std::size_t> idx) const
+    {
+        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
+    }
+
     typename std::vector<T>::iterator begin() { return mData.begin(); }

     typename std::vector<T>::iterator end() { return mData.end(); }
 ...
@@ -285,7 +351,8 @@ struct Tensor
 };

 template <typename X>
-HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(lens)
+HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens)
+    : mLens(lens.begin(), lens.end())
 {
     this->CalculateStrides();
 }
 ...
@@ -293,7 +360,7 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens) : mLens(l
 template <typename X, typename Y>
 HostTensorDescriptor::HostTensorDescriptor(const std::vector<X>& lens,
                                            const std::vector<Y>& strides)
-    : mLens(lens), mStrides(strides)
+    : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
 {
 }
 ...
@@ -349,5 +416,3 @@ float check_error(const Tensor<T>& ref, const Tensor<T>& result)
     return linf_error;
 }
-#endif
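The new ForEach walks every multi-index of the tensor by recursing over dimensions and hands the functor both the tensor and the current index vector; together with the new operator()(std::vector<std::size_t>) overload this allows index-based traversal without hand-written nested loops. A usage sketch, assuming the Tensor and HostTensorDescriptor defined in this header:

Tensor<float> t(HostTensorDescriptor(std::vector<std::size_t>{2, 3}));

// Fill element (i, j) with 10*i + j through the new index-vector access path.
t.ForEach([](auto& self, const std::vector<std::size_t>& idx) {
    self(idx) = static_cast<float>(10 * idx[0] + idx[1]);
});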
library/include/ck/library/host_tensor/host_tensor_generator.hpp

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once

 #include <cmath>
 #include <numeric>

-#include "config.hpp"
+#include "ck/ck.hpp"

 template <typename T>
 struct GeneratorTensor_0
 ...
@@ -18,12 +21,12 @@ struct GeneratorTensor_0
 template <typename T>
 struct GeneratorTensor_1
 {
-    int value = 1;
+    T value = 1;

     template <typename... Is>
     T operator()(Is...)
     {
-        return ck::type_convert<T>(value);
+        return value;
     }
 };
 ...
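Storing the fill value as T instead of int lets GeneratorTensor_1 carry non-integral constants: the old int member could not represent fractional fill values, so the explicit ck::type_convert on return is no longer needed either. A quick sketch with the changed struct (GenerateTensorValue comes from host_tensor.hpp):

Tensor<float> t(HostTensorDescriptor(std::vector<std::size_t>{4, 4}));

t.GenerateTensorValue(GeneratorTensor_1<float>{0.5f}); // every element becomes 0.5f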
library/include/ck/library/obselete_driver_offline/debug.hpp
deleted 100644 → 0

#ifndef DEBUG_HPP
#define DEBUG_HPP

namespace debug {
namespace debug_driver_gemm_xdlops_v2r3 {

// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping
static ck::index_t M01 = 1;
static ck::index_t N01 = 1;

} // namespace debug_driver_gemm_xdlops_v2r3
} // namespace debug
#endif
library/include/ck/library/obselete_driver_offline/device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp
deleted 100644 → 0

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          ck::ActivTypeEnum activ_type,
          typename InLengths,
          typename WeiLengths,
          typename AddLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_add_forward_implicit_gemm_v5r1_dlops_nc0hwc1_kc0yxc1_nk0hwk1(
    const InLengths& in_n_c0_hi_wi_c1_lengths,
    const WeiLengths& wei_k_c0_y_x_c1_lengths,
    const AddLengths& add_n_k0_hox2_wox2_k1_lengths,
    const OutLengths& out_n_k0_ho_wo_k1_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c0_hi_wi_c1,
    const Tensor<TInWei>& wei_k_c0_y_x_c1,
    const Tensor<TOut>& bias_k0_k1,
    const Tensor<TOut>& add_n_k0_hox2_wox2_k1,
    Tensor<TOut>& add_n_k0_hox2_wox2_k1_out,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};

    const auto N  = out_n_k0_ho_wo_k1_lengths[I0];
    const auto K0 = out_n_k0_ho_wo_k1_lengths[I1];
    const auto Ho = out_n_k0_ho_wo_k1_lengths[I2];
    const auto Wo = out_n_k0_ho_wo_k1_lengths[I3];
    const auto K1 = out_n_k0_ho_wo_k1_lengths[I4];

    const auto C0 = in_n_c0_hi_wi_c1_lengths[I1];
    const auto Hi = in_n_c0_hi_wi_c1_lengths[I2];
    const auto Wi = in_n_c0_hi_wi_c1_lengths[I3];
    const auto C1 = in_n_c0_hi_wi_c1_lengths[I4];

    const auto K = wei_k_c0_y_x_c1_lengths[I0];
    const auto Y = wei_k_c0_y_x_c1_lengths[I2];
    const auto X = wei_k_c0_y_x_c1_lengths[I3];

    const auto Hox2 = add_n_k0_hox2_wox2_k1_lengths[I2];
    const auto Wox2 = add_n_k0_hox2_wox2_k1_lengths[I3];

    DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
                                          in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
    DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) *
                                         wei_k_c0_y_x_c1.mDesc.GetElementSpace());
    DeviceMem bias_k0_k1_device_buf(sizeof(TOut) * bias_k0_k1.mDesc.GetElementSpace());
    DeviceMem add_n_k0_hox2_wox2_k1_device_buf(sizeof(TOut) *
                                               add_n_k0_hox2_wox2_k1.mDesc.GetElementSpace());

    in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
    wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
    bias_k0_k1_device_buf.ToDevice(bias_k0_k1.mData.data());
    add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());

    constexpr index_t InWeiVectorSize = 8;

    if(C1 % InWeiVectorSize != 0)
    {
        throw std::runtime_error("wrong! C1 cannot be divided by InWeiVectorSize");
    }

#if 0
    constexpr index_t BlockSize = 256;

    constexpr index_t KPerBlock  = 32;
    constexpr index_t HoPerBlock = 8;
    constexpr index_t WoPerBlock = 64;

    constexpr index_t E1         = C0 * 9;
    constexpr index_t E2         = 1;
    constexpr index_t E1PerBlock = C0;

    constexpr index_t KPerThread  = 16;
    constexpr index_t HoPerThread = 2;
    constexpr index_t WoPerThread = 2;
    constexpr index_t EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, KPerBlock, 1>;

    constexpr index_t ABlockTransferSrcScalarPerVector_E2  = E2;
    constexpr index_t ABlockTransferDstScalarPerVector_E2  = E2;
    constexpr index_t BThreadTransferSrcScalarPerVector_E2 = E2;
    constexpr index_t CThreadTransferDstScalarPerVector_K  = K1;
#elif 1
    constexpr auto BlockSize = 64;

    constexpr auto KPerBlock  = 8;
    constexpr auto HoPerBlock = 8;
    constexpr auto WoPerBlock = 32;

    constexpr auto E1 = 2 * 9;
    constexpr auto E2 = 1;
    constexpr auto K2 = 2;

    constexpr auto E1PerBlock = 2;

    constexpr auto KPerThread  = KPerBlock;
    constexpr auto HoPerThread = 2;
    constexpr auto WoPerThread = 2;
    constexpr auto EPerThread  = 1;

    using ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2   = Sequence<1, 9, 1, 1, E2>;
    using ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2 = Sequence<1, E1PerBlock, 1, KPerBlock, 1>;

    constexpr auto ABlockTransferSrcScalarPerVector_E2  = E2;
    constexpr auto ABlockTransferDstScalarPerVector_E2  = E2;
    constexpr auto BThreadTransferSrcScalarPerVector_E2 = E2;
    constexpr auto CThreadTransferDstScalarPerVector_K  = InWeiVectorSize;
#endif

    const auto in_n_c0_hi_wi_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi, E2));
    const auto wei_k_c0_y_x_c1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X, E2));
    const auto add_n_k0_hox2_wox2_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Hox2, Wox2, K1));
    const auto out_n_k0_ho_wo_k1_desc =
        make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

    constexpr auto conv_driver =
        DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0hwk1_add<
            BlockSize,
            typename vector_type<TInWei, InWeiVectorSize>::type,
            TAcc,
            TOut,
            E1,
            E2,
            K2,
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            E1PerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2,
            ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2,
            ABlockTransferSrcScalarPerVector_E2,
            ABlockTransferDstScalarPerVector_E2,
            BThreadTransferSrcScalarPerVector_E2,
            CThreadTransferDstScalarPerVector_K,
            activ_type>{};

    std::cerr << "conv_bias_activ_resize_add_input_"
              << "n" << N << "c" << C0 << "h" << Hi << "w" << Wi << "c" << C1
              << "_filter_k" << K << "c" << C0 << "y" << Y << "x" << X << "c" << C1
              << "_addout_n" << N << "k" << K0 << "h" << Ho * 2 << "w" << Wo * 2 << "k" << K1
              << std::endl;

    for(int i = 0; i < 5; i++)
    {
        const auto ave_time =
            conv_driver.Run(wei_k_c0_y_x_c1_desc,
                            in_n_c0_hi_wi_c1_desc,
                            out_n_k0_ho_wo_k1_desc,
                            add_n_k0_hox2_wox2_k1_desc,
                            conv_strides,
                            conv_dilations,
                            in_left_pads,
                            in_right_pads,
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                            static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                                in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                            static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
                            nrepeat);

        {
            float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C0 * C1 * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    add_n_k0_hox2_wox2_k1_device_buf.ToDevice(add_n_k0_hox2_wox2_k1.mData.data());

    conv_driver.Run(wei_k_c0_y_x_c1_desc,
                    in_n_c0_hi_wi_c1_desc,
                    out_n_k0_ho_wo_k1_desc,
                    add_n_k0_hox2_wox2_k1_desc,
                    conv_strides,
                    conv_dilations,
                    in_left_pads,
                    in_right_pads,
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
                        in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(bias_k0_k1_device_buf.GetDeviceBuffer()),
                    static_cast<TOut*>(add_n_k0_hox2_wox2_k1_device_buf.GetDeviceBuffer()),
                    0);

    add_n_k0_hox2_wox2_k1_device_buf.FromDevice(add_n_k0_hox2_wox2_k1_out.mData.data());
}
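The perf line divides the FLOP count by 10^9 and then by the average time in milliseconds; since 1 ms = 10^-3 s, the quotient comes out directly in TFlop/s:

\text{perf} = \frac{2\,N\,K\,H_o\,W_o\,C_0\,C_1\,Y\,X}{10^{9}\cdot t_{\text{ms}}}\ \text{TFlop/s},
\qquad
\frac{\text{FLOPs}}{10^{9}\, t_{\text{ms}}} = \frac{\text{FLOPs}}{10^{12}\, t_{\text{s}}}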
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp
deleted 100644 → 0

#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
#include "debug.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 4;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 2;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 2;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;

    constexpr index_t GemmMPerXDL = 32;
    constexpr index_t GemmNPerXDL = 32;
    constexpr index_t GemmK1      = 8;

    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;

    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1   = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM  = 2;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;

    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1   = Sequence<1, 4, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;

    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif

    const auto descs = transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        wei_k_y_x_c_desc,
        out_n_ho_wo_k_desc,
        in_n_hi_wi_c_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        I0,
        I0,
        Number<GemmK1>{});

    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto in_gemmm_gemmn_grid_desc          = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0,0,0,0,0,0,0,0,0,0,0,1,0,0>{},   // 0+: GemmK0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},   // 1+: GemmM
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0,0,0,0,0,0,0,0,0,0,0,2,0,0>{},   // 0-: GemmK0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},   // 1-: GemmM
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0>{})); // 2-: GemmK1

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0>{},   // 0+: GemmK0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0>{},   // 1+: GemmN
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0>{},   // 0-: GemmK0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0>{},   // 1-: GemmN
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{})); // 2-: GemmK1

    // clang-format off
    constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 0+: M0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0>{},  // 1+: N0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 2+: M1
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0>{},  // 3+: N1
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 4+: M2
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 5+: M3
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 6+: M4
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0>{}), // 7+: N2
        make_tuple(Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 0-: M0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,0>{},  // 1-: N0
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 2-: M1
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,0>{},  // 3-: N1
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 4-: M2
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 5-: M3
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0>{},  // 6-: M4
                   Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,0,0>{})); // 7-: N2
    // clang-format on

    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0,0,0,0,0,0,0,0,0,0,0,1,0,0>{};

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,2,0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperationEnum::Set,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(in_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerXDL,
            GemmNPerXDL,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<2, 0, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmABlockTransferSrcScalarPerVector_GemmM,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmBBlockTransferSrcScalarPerVector_GemmK1,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
            6,
            GemmCThreadTransferDstScalarPerVector,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false, // CAccessOrderMRepeatNRepeat
            false, // ABlockLdsExtraM
            false  // BBlockLdsExtraN
            >(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
              static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
              wei_gemmk0_gemmm_gemmk1_grid_desc,
              out_gemmk0_gemmn_gemmk1_grid_desc,
              in_gemmm_gemmn_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
              out_gemmk0_gemmn_gemmk1_grid_step_hacks,
              in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        {
            const auto N  = out_n_ho_wo_k_lengths[I0];
            const auto K  = out_n_ho_wo_k_lengths[I3];
            const auto C  = wei_k_y_x_c_lengths[I3];
            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];
            const auto Y  = wei_k_y_x_c_lengths[I1];
            const auto X  = wei_k_y_x_c_lengths[I2];

            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // copy result back to host
    in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp
deleted
100644 → 0
View file @
e07b3d8e
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
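The #if/#elif ladder above is a hand-tuned configuration table: exactly one branch is active, and each branch must keep the block-transfer slice/cluster lengths consistent with the block tile. A minimal stand-alone sketch of that invariant, using the values from the active [128, 256, 4, 8] fp16 branch (the check itself is hypothetical, not something the original driver runs):

#include <cassert>

int main()
{
    const int BlockSize  = 256;
    const int slice[3]   = {1, 2, 8};   // GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1
    const int cluster[3] = {4, 64, 1};  // GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1
    const int tile[3]    = {4, 128, 8}; // [GemmKPerBlock, GemmMPerBlock, GemmK1]

    // the thread cluster uses every thread in the block, and slice * cluster tiles the block tile
    assert(cluster[0] * cluster[1] * cluster[2] == BlockSize);
    for(int d = 0; d < 3; ++d)
        assert(slice[d] * cluster[d] == tile[d]);

    return 0;
}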
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: gemmk0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{},   // 1+: gemmm
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 2+: gemmk1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: gemmk0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{},   // 1-: gemmm
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1

    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: gemmk0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: gemmn
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 2+: gemmk1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: Gemmk0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: Gemmn
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1

    // clang-format off
    constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0+: M0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 1+: N0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 2+: M1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 3+: N1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 4+: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 5+: M3
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 6+: M4
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0-: M0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 1-: N0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 2-: M1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 3-: N1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 4-: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 5-: M3
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 6-: M4
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
    // clang-format on

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};

    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
    for(index_t i = 0; i < 5; ++i)
    {
        const auto ConvStrideH = conv_strides[I0];
        const auto ConvStrideW = conv_strides[I1];

        const auto ConvDilationH = conv_dilations[I0];
        const auto ConvDilationW = conv_dilations[I1];

        const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
        const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

        const auto YTilde = ConvStrideH / GcdStrideDilationH;
        const auto XTilde = ConvStrideW / GcdStrideDilationW;

        float ave_time = 0;

        for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
        {
            for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
            {
                const auto descs =
                    transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
                        out_n_ho_wo_k_desc,
                        wei_k_y_x_c_desc,
                        in_n_hi_wi_c_desc,
                        conv_strides,
                        conv_dilations,
                        in_left_pads,
                        in_right_pads,
                        i_ytilde,
                        i_xtilde,
                        Number<GemmK1>{});

                const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
                const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
                const auto in_gemmm_gemmn_grid_desc          = descs[I2];

                const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0);

                if(GemmK0 != 0)
                {
                    ave_time += driver_gemm_xdlops_v2r3<
                        BlockSize,
                        TInWei,
                        TAcc,
                        TOut,
                        InMemoryDataOperationEnum::Set,
                        decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
                        decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
                        decltype(in_gemmm_gemmn_grid_desc),
                        GemmMPerBlock,
                        GemmNPerBlock,
                        GemmKPerBlock,
                        GemmMPerWave,
                        GemmNPerWave,
                        GemmK1,
                        MRepeat,
                        NRepeat,
                        GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
                        GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
                        Sequence<1, 0, 2>,
                        Sequence<1, 0, 2>,
                        2,
                        GemmABlockTransferSrcScalarPerVector_GemmK1,
                        GemmABlockTransferDstScalarPerVector_GemmK1,
                        false, // don't move back src coordinate after threadwise copy
                        GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
                        GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
                        Sequence<2, 0, 1>,
                        Sequence<0, 2, 1>,
                        1,
                        GemmBBlockTransferSrcScalarPerVector_GemmN,
                        GemmBBlockTransferDstScalarPerVector_GemmK1,
                        false, // don't move back src coordinate after threadwise copy
#if 0
                        Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
                        Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
                        7,
                        GemmCThreadTransferDstScalarPerVector,
                        decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
                        decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
                        decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
                        decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
                        decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
                        true,  // CAccessOrderMRepeatNRepeat
                        false, // ABlockLdsExtraM
                        false  // BBlockLdsExtraN
                        >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
                          static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
                          static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
                          out_gemmk0_gemmm_gemmk1_grid_desc,
                          wei_gemmk0_gemmn_gemmk1_grid_desc,
                          in_gemmm_gemmn_grid_desc,
                          debug::debug_driver_gemm_xdlops_v2r3::M01,
                          debug::debug_driver_gemm_xdlops_v2r3::N01,
                          out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                          wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
                          in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                          out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                          wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                          nrepeat);
                }
            }
        }

        {
            const auto N  = out_n_ho_wo_k_lengths[I0];
            const auto K  = out_n_ho_wo_k_lengths[I3];
            const auto C  = wei_k_y_x_c_lengths[I3];
            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];
            const auto Y  = wei_k_y_x_c_lengths[I1];
            const auto X  = wei_k_y_x_c_lengths[I2];

            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // copy result back to host
    in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}
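For context on the loop structure above: the v4r1r2 backward-data algorithm partitions the problem into YTilde * XTilde sub-GEMMs, where YTilde = ConvStrideH / gcd(ConvStrideH, ConvDilationH) (and likewise along W), so that, roughly, each sub-GEMM covers its own subset of input pixels. A minimal sketch of that arithmetic with hypothetical values:

#include <cstdio>
#include <numeric> // std::gcd (C++17)

int main()
{
    const int ConvStrideH = 2, ConvDilationH = 3; // hypothetical conv config

    const int GcdStrideDilationH = std::gcd(ConvStrideH, ConvDilationH);
    const int YTilde             = ConvStrideH / GcdStrideDilationH;

    std::printf("YTilde = %d sub-GEMMs along H\n", YTilde); // prints 2

    return 0;
}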
library/include/ck/library/obselete_driver_offline/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp deleted 100644 → 0
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations&,
    const InLeftPads&,
    const InRightPads&,
    Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    const Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc  = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc   = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 4;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 256;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 4;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16
    constexpr index_t BlockSize = 128;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
    // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 64;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 1;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0>{},   // 0+: gemmk0
                   Sequence<0, 0, 0>{},   // 1+: gemmm
                   Sequence<0, 0, 0>{}),  // 2+: gemmk1
        make_tuple(Sequence<0, 0, 0>{},   // 0-: gemmk0
                   Sequence<0, 0, 0>{},   // 1-: gemmm
                   Sequence<0, 0, 0>{})); // 2-: gemmk1

    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0>{},   // 0+: gemmk0
                   Sequence<0, 0, 0>{},   // 1+: gemmn
                   Sequence<0, 0, 0>{}),  // 2+: gemmk1
        make_tuple(Sequence<0, 0, 0>{},   // 0-: Gemmk0
                   Sequence<0, 0, 0>{},   // 1-: Gemmn
                   Sequence<0, 0, 0>{})); // 2-: Gemmk1

    // clang-format off
    constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0+: M0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 1+: N0
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 2+: M1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 3+: N1
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 4+: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 5+: M3
            Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 6+: M4
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0-: M0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 1-: N0
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 2-: M1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 3-: N1
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 4-: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 5-: M3
            Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 6-: M4
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
    // clang-format on

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};

    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
    for(index_t i = 0; i < 5; ++i)
    {
        const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
            out_n_ho_wo_k_desc, wei_k_y_x_c_desc, in_n_hi_wi_c_desc, conv_strides, Number<GemmK1>{});

        const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
        const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
        const auto in_gemmm_gemmn_grid_desc          = descs[I2];

        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperationEnum::Set,
            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(in_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmABlockTransferSrcScalarPerVector_GemmK1,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<2, 0, 1>,
            Sequence<0, 2, 1>,
            1,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
#if 0
            Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            true,  // CAccessOrderMRepeatNRepeat
            false, // ABlockLdsExtraM
            false  // BBlockLdsExtraN
            >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
              static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
              out_gemmk0_gemmm_gemmk1_grid_desc,
              wei_gemmk0_gemmn_gemmk1_grid_desc,
              in_gemmm_gemmn_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
              wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
              in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        {
            const auto N  = out_n_ho_wo_k_lengths[I0];
            const auto K  = out_n_ho_wo_k_lengths[I3];
            const auto C  = wei_k_y_x_c_lengths[I3];
            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];
            const auto Y  = wei_k_y_x_c_lengths[I1];
            const auto X  = wei_k_y_x_c_lengths[I2];

            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

    // copy result back to host
    in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}
library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp deleted 100644 → 0
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r4.hpp"
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          typename GridSizeType>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_c_hi_wi,
    Tensor<TWei>& wei_k_c_y_x,
    const Tensor<TOut>& out_n_k_ho_wo,
    GridSizeType desired_grid_size,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc  = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 1
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp32
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
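The comment in the block above ties GemmABlockTransferSrcScalarPerVector_GemmK1 = 4 to a layout requirement: the A matrix (the NKHW output gradient) is read contiguously along Ho*Wo, so Ho*Wo must divide into 4-wide vector loads. A hypothetical host-side guard, not present in the original driver:

#include <cassert>

int main()
{
    const int Ho = 28, Wo = 28; // hypothetical output spatial sizes
    const int GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;

    // "config's wo*ho must be a multiple of 4"
    assert((Ho * Wo) % GemmABlockTransferSrcScalarPerVector_GemmK1 == 0);

    return 0;
}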
    const auto N  = in_n_c_hi_wi_desc.GetLength(I0);
    const auto C  = in_n_c_hi_wi_desc.GetLength(I1);
    const auto K  = out_n_k_ho_wo_desc.GetLength(I1);
    const auto Ho = out_n_k_ho_wo_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_desc.GetLength(I3);
    const auto Y  = wei_k_c_y_x_desc.GetLength(I2);
    const auto X  = wei_k_c_y_x_desc.GetLength(I3);

    const auto GemmM      = K;
    const auto GemmN      = Y * X * C;
    const auto GemmKTotal = N * Ho * Wo;

    const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock);

    const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1);
    const index_t GemmK0 =
        math::integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock;
    const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN
              << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0
              << " GemmKPad: " << GemmKPad << std::endl;

    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        Number<GemmK1>{},
        GemmKBatch,
        GemmKPad);

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];
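The GemmKBatch / GemmK0 / GemmKPad computation above splits the reduction dimension across extra workgroups (which then accumulate into C via AtomicAdd) and pads K so every batch runs a whole number of GemmKPerBlock-sized loops. A stand-alone sketch of the same arithmetic with hypothetical sizes:

#include <algorithm>
#include <cstdio>

int main()
{
    // hypothetical problem: GemmKTotal = N * Ho * Wo
    const int GemmKTotal = 64 * 28 * 28;
    const int GemmK1 = 8, GemmKPerBlock = 4;
    const int GridMN = 4, desired_grid_size = 256;

    const int GemmKBatch = std::max(desired_grid_size / GridMN, 1);

    // integer_divide_ceil(GemmKTotal, GemmK1 * GemmKPerBlock * GemmKBatch) * GemmKPerBlock
    const int denom    = GemmK1 * GemmKPerBlock * GemmKBatch;
    const int GemmK0   = (GemmKTotal + denom - 1) / denom * GemmKPerBlock;
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1;

    std::printf("GemmKBatch=%d GemmK0=%d GemmKPad=%d (>= GemmKTotal=%d)\n",
                GemmKBatch, GemmK0, GemmKPad, GemmKTotal);

    return 0;
}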
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmB
                   Sequence<0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmM
                   Sequence<0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmB
                   Sequence<0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmM
                   Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 0+: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},   // 1+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 0-: GemmB
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 1-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{};
    const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4<
        BlockSize,
        TIn,
        TAcc,
        TWei,
        InMemoryDataOperationEnum::AtomicAdd,
        decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
        decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
        decltype(wei_gemmm_gemmn_grid_desc),
        GemmMPerBlock,
        GemmNPerBlock,
        GemmKPerBlock,
        GemmMPerWave,
        GemmNPerWave,
        GemmK1,
        MRepeat,
        NRepeat,
        GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1,
        GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1,
        Sequence<0, 2, 1, 3>,
        Sequence<0, 2, 1, 3>,
        3,
        GemmABlockTransferSrcScalarPerVector_GemmK1,
        GemmABlockTransferDstScalarPerVector_GemmK1,
        false, // don't move back src coordinate after threadwise copy
        GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1,
        GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1,
        Sequence<0, 2, 1, 3>,
        Sequence<0, 2, 1, 3>,
        3,
        GemmBBlockTransferSrcScalarPerVector_GemmN,
        GemmBBlockTransferDstScalarPerVector_GemmK1,
        false, // don't move back src coordinate after threadwise copy
        Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
        7,
        GemmCThreadTransferDstScalarPerVector,
        decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
        decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
        decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
        decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
        decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
        false,
        true,
        true>;

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time =
            driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                               static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                               static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                               out_gemmk0_gemmm_gemmk1_grid_desc,
                               in_gemmk0_gemmn_gemmk1_grid_desc,
                               wei_gemmm_gemmn_grid_desc,
                               debug::debug_driver_gemm_xdlops_v2r3::M01,
                               debug::debug_driver_gemm_xdlops_v2r3::N01,
                               out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                               in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                               wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                               out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                               in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                               nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                  << std::endl;
    }

    // with AtomicAdd the timed repeats keep accumulating into the weight buffer,
    // so re-initialize it and run once more un-timed (nrepeat = 0) before reading back
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());

    driver_gemm_xdlops(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
                       static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
                       static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
                       out_gemmk0_gemmm_gemmk1_grid_desc,
                       in_gemmk0_gemmn_gemmk1_grid_desc,
                       wei_gemmm_gemmn_grid_desc,
                       debug::debug_driver_gemm_xdlops_v2r3::M01,
                       debug::debug_driver_gemm_xdlops_v2r3::N01,
                       out_gemmk0_gemmm_gemmk1_grid_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_step_hacks,
                       wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
                       out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
                       in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
                       0);

    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}
library/include/ck/library/obselete_driver_offline/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp deleted 100644 → 0
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TIn,
          typename TWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TIn>& in_n_c_hi_wi,
    Tensor<TWei>& wei_k_c_y_x,
    const Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc  = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc   = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 0
    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 2;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
    constexpr index_t BlockSize = 256;
    constexpr index_t GemmMPerBlock = 256;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmMPerWave = 32;
    constexpr index_t GemmNPerWave = 32;
    constexpr index_t GemmK1 = 8;
    constexpr index_t MRepeat = 4;
    constexpr index_t NRepeat = 2;
    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
    // using vector load 4, so config's wo*ho must be a multiple of 4
    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
    const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        conv_strides,
        conv_dilations,
        in_left_pads,
        in_right_pads,
        Number<GemmK1>{});

    const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
    const auto in_gemmk0_gemmn_gemmk1_grid_desc  = descs[I1];
    const auto wei_gemmm_gemmn_grid_desc         = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 1, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM
                   Sequence<0, 0, 1, 0, 0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0, 0, 2, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM
                   Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},   // 1+: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),  // 2+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},   // 1-: GemmN
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1

    constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 1, 0, 0>{};

    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TIn,
            TAcc,
            TWei,
            InMemoryDataOperationEnum::Set,
            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(wei_gemmm_gemmn_grid_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmK1,
            MRepeat,
            NRepeat,
            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmABlockTransferSrcScalarPerVector_GemmK1,
            GemmABlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
            Sequence<1, 0, 2>,
            Sequence<1, 0, 2>,
            2,
            GemmBBlockTransferSrcScalarPerVector_GemmN,
            GemmBBlockTransferDstScalarPerVector_GemmK1,
            false, // don't move back src coordinate after threadwise copy
            Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
            7,
            GemmCThreadTransferDstScalarPerVector,
            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false, // CAccessOrderMRepeatNRepeat
            true,  // ABlockLdsExtraM
            true   // BBlockLdsExtraN
            >(static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
              static_cast<TIn*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
              static_cast<TWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
              out_gemmk0_gemmm_gemmk1_grid_desc,
              in_gemmk0_gemmn_gemmk1_grid_desc,
              wei_gemmm_gemmn_grid_desc,
              debug::debug_driver_gemm_xdlops_v2r3::M01,
              debug::debug_driver_gemm_xdlops_v2r3::N01,
              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_step_hacks,
              wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        float perf = static_cast<float>(calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                  << std::endl;
    }

    // copy result back to host
    wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data());
}