Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1b15b21a
"tests/vscode:/vscode.git/clone" did not exist on "5fd3dca5f377126b73a9af8aaf7a6291951d201c"
Commit
1b15b21a
authored
Dec 13, 2021
by
Chao Liu
Browse files
update static_tensor for dealing with invalid element
parent
2fd5e6ae
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
240 additions
and
256 deletions
+240
-256
composable_kernel/include/tensor_description/static_tensor.hpp
...sable_kernel/include/tensor_description/static_tensor.hpp
+20
-15
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
...lude/tensor_operation/blockwise_tensor_slice_transfer.hpp
+10
-0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
+144
-77
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp
...ensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp
+6
-1
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+1
-1
composable_kernel/include/utility/config.hpp
composable_kernel/include/utility/config.hpp
+5
-1
composable_kernel/include/utility/utility.hpp
composable_kernel/include/utility/utility.hpp
+4
-0
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
...e/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
+50
-10
device_operation/include/element_wise_operation.hpp
device_operation/include/element_wise_operation.hpp
+0
-151
No files found.
composable_kernel/include/tensor_description/static_tensor.hpp
View file @
1b15b21a
#ifndef CK_STATIC_TENSOR_HPP
#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP
#include "ignore.hpp"
namespace
ck
{
namespace
ck
{
// StaticTensor for Scalar
// StaticTensor for Scalar
...
@@ -17,10 +15,10 @@ struct StaticTensor
...
@@ -17,10 +15,10 @@ struct StaticTensor
static
constexpr
index_t
ndim_
=
TensorDesc
::
GetNumOfDimension
();
static
constexpr
index_t
ndim_
=
TensorDesc
::
GetNumOfDimension
();
static
constexpr
index_t
element_space_size_
=
desc_
.
GetElementSpaceSize
();
static
constexpr
index_t
element_space_size_
=
desc_
.
GetElementSpaceSize
();
__host__
__device__
constexpr
StaticTensor
()
:
invalid_element_value_
{
0
}
{}
__host__
__device__
constexpr
StaticTensor
()
:
invalid_element_
scalar_
value_
{
0
}
{}
__host__
__device__
constexpr
StaticTensor
(
T
invalid_element_value
)
__host__
__device__
constexpr
StaticTensor
(
T
invalid_element_value
)
:
invalid_element_value_
{
invalid_element_value
}
:
invalid_element_
scalar_
value_
{
invalid_element_value
}
{
{
}
}
...
@@ -44,11 +42,11 @@ struct StaticTensor
...
@@ -44,11 +42,11 @@ struct StaticTensor
{
{
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
{
return
T
{
0
}
;
return
zero_scalar_value_
;
}
}
else
else
{
{
return
invalid_element_value_
;
return
invalid_element_
scalar_
value_
;
}
}
}
}
}
}
...
@@ -71,12 +69,14 @@ struct StaticTensor
...
@@ -71,12 +69,14 @@ struct StaticTensor
}
}
else
else
{
{
return
ignore
;
return
ignore
d_element_scalar_
;
}
}
}
}
StaticBuffer
<
AddressSpace
,
T
,
element_space_size_
,
true
>
data_
;
StaticBuffer
<
AddressSpace
,
T
,
element_space_size_
,
true
>
data_
;
T
invalid_element_value_
=
T
{
0
};
static
constexpr
T
zero_scalar_value_
=
T
{
0
};
const
T
invalid_element_scalar_value_
;
T
ignored_element_scalar_
;
};
};
// StaticTensor for vector
// StaticTensor for vector
...
@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -97,10 +97,13 @@ struct StaticTensorTupleOfVectorBuffer
using
V
=
vector_type
<
S
,
ScalarPerVector
>
;
using
V
=
vector_type
<
S
,
ScalarPerVector
>
;
__host__
__device__
constexpr
StaticTensorTupleOfVectorBuffer
()
:
invalid_element_value_
{
0
}
{}
__host__
__device__
constexpr
StaticTensorTupleOfVectorBuffer
()
:
invalid_element_scalar_value_
{
0
}
{
}
__host__
__device__
constexpr
StaticTensorTupleOfVectorBuffer
(
S
invalid_element_value
)
__host__
__device__
constexpr
StaticTensorTupleOfVectorBuffer
(
S
invalid_element_value
)
:
invalid_element_value_
{
invalid_element_value
}
:
invalid_element_
scalar_
value_
{
invalid_element_value
}
{
{
}
}
...
@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -125,11 +128,11 @@ struct StaticTensorTupleOfVectorBuffer
{
{
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
if
constexpr
(
InvalidElementUseNumericalZeroValue
)
{
{
return
S
{
0
}
;
return
zero_scalar_value_
;
}
}
else
else
{
{
return
invalid_element_value_
;
return
invalid_element_
scalar_
value_
;
}
}
}
}
}
}
...
@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -153,7 +156,7 @@ struct StaticTensorTupleOfVectorBuffer
}
}
else
else
{
{
return
ignore
;
return
ignore
d_element_scalar_
;
}
}
}
}
...
@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -186,7 +189,7 @@ struct StaticTensorTupleOfVectorBuffer
else
else
{
{
// TODO: is this right way to initialize a vector?
// TODO: is this right way to initialize a vector?
return
X
{
invalid_element_value_
};
return
X
{
invalid_element_
scalar_
value_
};
}
}
}
}
}
}
...
@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -237,7 +240,9 @@ struct StaticTensorTupleOfVectorBuffer
}
}
StaticBufferTupleOfVector
<
AddressSpace
,
S
,
num_of_vector_
,
ScalarPerVector
,
true
>
data_
;
StaticBufferTupleOfVector
<
AddressSpace
,
S
,
num_of_vector_
,
ScalarPerVector
,
true
>
data_
;
S
invalid_element_value_
=
S
{
0
};
static
constexpr
S
zero_scalar_value_
=
S
{
0
};
const
S
invalid_element_scalar_value_
=
S
{
0
};
S
ignored_element_scalar_
;
};
};
template
<
AddressSpaceEnum_t
AddressSpace
,
template
<
AddressSpaceEnum_t
AddressSpace
,
...
...
composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp
View file @
1b15b21a
...
@@ -114,6 +114,16 @@ struct BlockwiseTensorSliceTransfer_v4
...
@@ -114,6 +114,16 @@ struct BlockwiseTensorSliceTransfer_v4
}
}
}
}
template
<
typename
SrcBuffer
,
typename
DstBuffer
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
RunRead
(
src_desc
,
src_buf
);
RunWrite
(
dst_desc
,
dst_buf
);
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v3r1.hpp
View file @
1b15b21a
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r2.hpp
View file @
1b15b21a
...
@@ -165,6 +165,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
...
@@ -165,6 +165,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_idx
[
I0
];
index_t
tmp
=
ordered_src_access_idx
[
I0
];
// TODO: BUG: should start at 1
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_src_access_lengths
[
j
]
+
ordered_src_access_idx
[
j
];
tmp
=
tmp
*
ordered_src_access_lengths
[
j
]
+
ordered_src_access_idx
[
j
];
});
});
...
@@ -412,6 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
...
@@ -412,6 +413,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_idx
[
I0
];
index_t
tmp
=
ordered_dst_access_idx
[
I0
];
// TODO: BUG: should start at 1
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_idx
[
j
];
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_idx
[
j
];
});
});
...
@@ -512,7 +514,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
...
@@ -512,7 +514,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
template
<
typename
DstBuffer
>
template
<
typename
DstBuffer
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
{
constexpr
index_t
ntransform_dst
=
DstDesc
::
GetNumOfTransform
();
// TODO: why need remove_cvref_t ?
constexpr
index_t
ntransform_dst
=
remove_cvref_t
<
DstDesc
>::
GetNumOfTransform
();
constexpr
auto
zeros
=
typename
uniform_sequence_gen
<
ntransform_dst
,
0
>::
type
{};
constexpr
auto
zeros
=
typename
uniform_sequence_gen
<
ntransform_dst
,
0
>::
type
{};
...
@@ -545,6 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
...
@@ -545,6 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
forward_sweep_
(
I0
)
=
true
;
forward_sweep_
(
I0
)
=
true
;
// TODO: BUG: should start at 1
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_lengths
[
I0
]
-
1
;
index_t
tmp
=
ordered_src_access_lengths
[
I0
]
-
1
;
...
@@ -608,6 +612,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
...
@@ -608,6 +612,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_lengths
[
I0
]
-
1
;
index_t
tmp
=
ordered_dst_access_lengths
[
I0
]
-
1
;
// TODO: BUG: should start at 1
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_lengths
[
j
]
-
1
;
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_lengths
[
j
]
-
1
;
});
});
...
...
composable_kernel/include/utility/common_header.hpp
View file @
1b15b21a
...
@@ -35,8 +35,8 @@
...
@@ -35,8 +35,8 @@
#include "dynamic_buffer.hpp"
#include "dynamic_buffer.hpp"
#include "is_known_at_compile_time.hpp"
#include "is_known_at_compile_time.hpp"
#include "transpose_vectors.hpp"
#include "transpose_vectors.hpp"
#include "inner_product.hpp"
#include "inner_product.hpp"
#include "element_wise_operation.hpp"
// TODO: remove this
// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
#if CK_USE_AMD_INLINE_ASM
...
...
composable_kernel/include/utility/config.hpp
View file @
1b15b21a
...
@@ -24,12 +24,16 @@
...
@@ -24,12 +24,16 @@
#define CK_MIN_BLOCK_PER_CU 2
#define CK_MIN_BLOCK_PER_CU 2
#endif
#endif
//
buffer resou
rs
e
//
GPU-specific paramete
rs
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
// buffer resourse
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
// wave size
#define CK_GPU_WAVE_SIZE 64
#elif defined(CK_AMD_GPU_GFX1030)
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_GPU_WAVE_SIZE 32
#endif
#endif
// FMA instruction
// FMA instruction
...
...
composable_kernel/include/utility/utility.hpp
View file @
1b15b21a
...
@@ -5,8 +5,12 @@
...
@@ -5,8 +5,12 @@
namespace
ck
{
namespace
ck
{
__device__
constexpr
index_t
get_wave_size
()
{
return
CK_GPU_WAVE_SIZE
;
}
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_wave_local_1d_id
()
{
return
threadIdx
.
x
/
get_wave_size
();
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
}
// namespace ck
}
// namespace ck
...
...
device_operation/include/device_conv2d_fwd_xdl_output_shuffle_nhwc_kyxc_nhwk.hpp
View file @
1b15b21a
...
@@ -29,8 +29,8 @@ template <typename InDataType,
...
@@ -29,8 +29,8 @@ template <typename InDataType,
ck
::
index_t
NPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
K1
,
ck
::
index_t
MPerX
DL
,
ck
::
index_t
MPerX
dl
,
ck
::
index_t
NPerX
DL
,
ck
::
index_t
NPerX
dl
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
typename
ABlockTransferThreadSliceLengths_K0_M_K1
,
...
@@ -266,8 +266,8 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -266,8 +266,8 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
K0PerBlock
,
K0PerBlock
,
MPerX
DL
,
MPerX
dl
,
NPerX
DL
,
NPerX
dl
,
K1
,
K1
,
MXdlPerWave
,
MXdlPerWave
,
NXdlPerWave
,
NXdlPerWave
,
...
@@ -299,10 +299,12 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -299,10 +299,12 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
ABlockLdsAddExtraM
,
ABlockLdsAddExtraM
,
BBlockLdsAddExtraN
>
;
BBlockLdsAddExtraN
>
;
#if !DEBUG_USE_C_SHUFFLE
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
#endif
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
...
@@ -331,7 +333,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -331,7 +333,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
a_grid_desc_k0_m_k1_
{},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m_n_
{},
#if !DEBUG_USE_C_SHUFFLE
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
#else
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
{},
#endif
block_2_ctile_map_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
...
@@ -358,8 +364,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -358,8 +364,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
{
#if !DEBUG_USE_C_SHUFFLE
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
#else
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
(
c_grid_desc_m_n_
);
#endif
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
...
@@ -372,8 +385,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -372,8 +385,15 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
#if !DEBUG_USE_C_SHUFFLE
Block2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
#else
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
;
#endif
typename
GridwiseGemm
::
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
InElementwiseOperation
in_element_op_
;
InElementwiseOperation
in_element_op_
;
...
@@ -427,11 +447,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -427,11 +447,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
CDataType
,
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
#if !DEBUG_USE_C_SHUFFLE
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
#else
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
>
,
#endif
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
...
@@ -444,7 +470,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -444,7 +470,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
#if !DEBUG_USE_C_SHUFFLE
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
#else
arg
.
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
,
#endif
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
@@ -458,11 +488,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -458,11 +488,17 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
CDataType
,
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
#if !DEBUG_USE_C_SHUFFLE
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
#else
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MRepeat_MWaveMPerXdl_NBlock_NRepeat_NWaveNPerXdl
>
,
#endif
InElementwiseOperation
,
InElementwiseOperation
,
WeiElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
OutElementwiseOperation
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
Block2CTileMap
>
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
...
@@ -475,7 +511,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
...
@@ -475,7 +511,11 @@ struct DeviceConv2dFwdXdl_Output_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
#if !DEBUG_USE_C_SHUFFLE
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
#else
arg
.
c_grid_desc_mblock_mrepeat_mwavemperxdl_nblock_nrepeat_nwavenperxdl_
,
#endif
arg
.
in_element_op_
,
arg
.
in_element_op_
,
arg
.
wei_element_op_
,
arg
.
wei_element_op_
,
arg
.
out_element_op_
,
arg
.
out_element_op_
,
...
...
device_operation/include/element_wise_operation.hpp
deleted
100644 → 0
View file @
2fd5e6ae
#ifndef ELEMENT_WISE_OPERATION_HPP
#define ELEMENT_WISE_OPERATION_HPP
namespace
ck
{
namespace
tensor_operation
{
namespace
element_wise
{
struct
PassThrough
{
template
<
typename
T
>
__host__
__device__
constexpr
T
operator
()(
T
v
)
const
{
return
v
;
}
};
struct
AddRelu
{
template
<
typename
T1
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
)
const
{
float
b
=
v0
+
v1
;
float
c
=
b
>
0
?
b
:
0
;
return
c
;
}
template
<
typename
T1
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
)
const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
return b;
#else
float
b
=
v1
+
v0
;
float
c
=
b
>
0
?
b
:
0
;
return
c
;
#endif
}
};
struct
AddReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
float
b
=
v0
+
v1
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
#if 0
float a = v1 + v0;
float b = max(a, float(0));
float c = b + v2;
return c;
#else
float
b
=
v1
+
v2
;
float
c
=
(
v0
>
-
v1
)
?
b
+
v0
:
v2
;
return
c
;
#endif
}
};
struct
AddLeakyReluAdd
{
template
<
typename
T1
,
typename
T2
>
__host__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
float
a
=
v0
+
v1
;
float
b
=
0.1
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
}
template
<
typename
T1
,
typename
T2
>
__device__
constexpr
float
operator
()(
float
v0
,
T1
v1
,
T2
v2
)
const
{
#if 0
// this use not too many registers, but use fp64 mul
float a = v0 + v1;
float b = 0.1 * a;
float c = b > 0 ? b : 0;
float d = c + v2;
return d;
#elif
0
// this spill register
float
a
=
v0
+
v1
;
float
b
=
float
(
0.1
)
*
a
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#elif 0
// this use lots of registers (but no spill)
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
b
>
0
?
b
:
0
;
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this use lots of registers (but no spill), 89 Tflops
constexpr
float
alpha
=
0.1
;
constexpr
float
alpha_inv
=
1.0
/
alpha
;
float
a
=
v2
*
alpha_inv
;
float
b
=
v1
+
v0
;
float
c
=
max
(
b
,
float
(
0
));
float
d
=
alpha
*
(
a
+
c
);
return
d
;
#elif 1
// this spill registers, 89 Tflops
float
a
=
v0
+
v1
;
float
alpha
=
0.1
;
float
b
;
asm
volatile
(
"
\n
\
v_mul_f32_e32 %0, %1, %2
\n
\
"
:
"=v"
(
b
)
:
"s"
(
alpha
),
"v"
(
a
));
float
c
=
b
>
0
?
b
:
0
;
float
d
=
c
+
v2
;
return
d
;
#endif
}
};
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment