gaoqiong / composable_kernel · Commits

Commit 24af0144 (unverified)
Authored Nov 12, 2022 by Po Yen Chen; committed by GitHub, Nov 12, 2022

Merge branch 'develop' into gemm_layernorm_welford

Parents: 961f5e9e, b79bbbc2

Showing 20 changed files with 1357 additions and 360 deletions (+1357 −360)
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp                                        +59  −0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                                 +12  −1
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp                     +288 −0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp     +102 −71
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp       +52  −39
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp              +8   −3
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp               +2   −2
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp                    +4   −3
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp                      +9   −0
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp           +21  −2
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp               +34  −6
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp       +129 −0
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp             +0   −230
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp             +1   −1
library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp               +79  −0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp       +90  −0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp     +235 −0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp  +114 −0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp             +2   −2
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp          +116 −0
include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp

@@ -75,4 +75,63 @@ struct ThreadwiseWelford
     int max_count_;
 };
 
+template <typename T,
+          typename SrcMeanVarCountThreadDesc_M_K,
+          typename DstMeanVarThreadDesc_M,
+          bool GetActualVariance = false>
+struct ThreadwiseWelfordMerge
+{
+    static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
+    static constexpr auto dst_thread_desc_m   = DstMeanVarThreadDesc_M{};
+
+    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
+    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
+    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
+
+    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
+
+    __device__ static void
+    Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
+    {
+        int count            = count_a + count_b;
+        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
+
+        T delta = mean_b - mean_a;
+        mean_a += delta * count_b_over_count;
+        var_a += var_b + delta * delta * count_a * count_b_over_count;
+        count_a = count;
+    }
+
+    template <typename SrcMeanBufferType,
+              typename SrcVarBufferType,
+              typename SrcCountBufferType,
+              typename DstMeanBufferType,
+              typename DstVarBufferType,
+              typename DstCountBufferType>
+    __device__ static void Run(const SrcMeanBufferType& src_mean_buf,
+                               const SrcVarBufferType& src_var_buf,
+                               const SrcCountBufferType& src_count_buf,
+                               DstMeanBufferType& dst_mean_buf,
+                               DstVarBufferType& dst_var_buf,
+                               DstCountBufferType& dst_count_buf)
+    {
+        static_for<0, src_length_m, 1>{}([&](auto iM) {
+            static_for<0, src_length_k, 1>{}([&](auto iK) {
+                constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+
+                Merge(dst_mean_buf(iM),
+                      dst_var_buf(iM),
+                      dst_count_buf(iM),
+                      src_mean_buf[Number<src_offset>{}],
+                      src_var_buf[Number<src_offset>{}],
+                      src_count_buf[Number<src_offset>{}]);
+            });
+
+            if constexpr(GetActualVariance)
+            {
+                dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
+            };
+        });
+    };
+};
 } // namespace ck
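
ThreadwiseWelfordMerge::Merge combines two partial Welford accumulators with the standard pairwise (Chan-style) update: the merged mean shifts by delta * n_b / n, and the pooled sum of squared deviations gains delta^2 * n_a * n_b / n. The following is a minimal host-side sketch (plain C++, independent of CK's Number/static_for machinery; all names are illustrative, not part of the commit) that checks the merge against a single sequential pass:

    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct Welford { double mean = 0, m2 = 0; int count = 0; };

    // single-sample Welford update
    void push(Welford& w, double x)
    {
        ++w.count;
        double delta = x - w.mean;
        w.mean += delta / w.count;
        w.m2 += delta * (x - w.mean);
    }

    // same formula as ThreadwiseWelfordMerge::Merge; m2 here is the
    // unnormalized sum of squared deviations, divided out at the end
    void merge(Welford& a, const Welford& b)
    {
        int count    = a.count + b.count;
        double k     = count == 0 ? 0.0 : double(b.count) / count;
        double delta = b.mean - a.mean;
        a.mean += delta * k;
        a.m2 += b.m2 + delta * delta * a.count * k;
        a.count = count;
    }

    int main()
    {
        std::vector<double> xs = {1, 2, 3, 4, 5, 6, 7, 8};
        Welford lo, hi, all;
        for(int i = 0; i < 4; ++i) push(lo, xs[i]);
        for(int i = 4; i < 8; ++i) push(hi, xs[i]);
        for(double x : xs) push(all, x);

        merge(lo, hi); // merged partials must equal the sequential pass
        assert(std::abs(lo.mean - all.mean) < 1e-12);
        assert(std::abs(lo.m2 - all.m2) < 1e-12);
        std::printf("mean=%f var=%f\n", lo.mean, lo.m2 / lo.count);
    }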
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -594,6 +594,7 @@ struct XdlopsGemm
     static constexpr auto I5 = Number<5>{};
 
     using CIndex   = MultiIndex<2>;
+    using CIndex4D = MultiIndex<4>;
 
     __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }

@@ -822,6 +823,16 @@ struct XdlopsGemm
         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }
 
+    __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
+    {
+        const auto blk_idx = GetBlkIdx();
+
+        const auto blk_id = blk_idx[I0];
+        const auto blk_td = blk_idx[I1];
+
+        return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
+    }
+
     static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops>{};
 
     static constexpr auto mfma_instr = mfma.selected_mfma;
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"

namespace ck {
namespace tensor_operation {

// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, device::TensorSpecialization TensorSpec>
static auto MakeGridDescriptorPair(const std::vector<index_t>& gs_ms_ns_lengths_vec,
                                   const std::vector<index_t>& gs_ms_ns_strides_vec)
{
    if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
         gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN))
    {
        throw std::runtime_error("wrong! dimension must match input lengths");
    }

    const auto to_tuple = [&](auto& vec, auto start, auto end) {
        return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
    };

    const auto gs_ms_ns_lengths =
        to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
    const auto gs_ms_ns_strides =
        to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

    // dimension Ids for G0, G1, ...
    constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

    // dimension Ids for M0, M1, ...
    constexpr auto mDimIds =
        typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

    // dimension Ids for N0, N1, ...
    constexpr auto nDimIds =
        typename arithmetic_sequence_gen<NumDimG + NumDimM, NumDimG + NumDimM + NumDimN, 1>::type{};

    // lengths for G0, G1, ...
    const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds);

    // lengths for M0, M1, ...
    const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds);

    // lengths for N0, N1, ...
    const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds);

    if constexpr(TensorSpec == device::TensorSpecialization::Packed)
    {
        auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
        auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
        auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});

        const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(G, M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor(
            make_tuple(M, N),
            make_tuple(gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                       gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
    else
    {
        // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
        const auto grid_desc_gs_ms_ns =
            make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides);

        // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        // Note: This does not require padding as it only provides G offset calculation. Technically
        // descriptor for only G is needed. Here we opt for backward compatibility purpose to return
        // G_M_N
        const auto grid_desc_g_mraw_nraw = transform_tensor_descriptor(
            grid_desc_gs_ms_ns,
            make_tuple(make_merge_transform(gLengths),
                       make_merge_transform(mLengths),
                       make_merge_transform(nLengths)),
            make_tuple(gDimIds, mDimIds, nDimIds),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        const auto c_ms_ns_lengths =
            to_tuple(gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto c_ms_ns_strides =
            to_tuple(gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // transformed tensor C[MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 * N2 * ...]
        const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides);

        const auto grid_desc_mraw_nraw = transform_tensor_descriptor(
            grid_desc_ms_ns,
            make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
            make_tuple(mDimIds - Number<NumDimG>{}, nDimIds - Number<NumDimG>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw);
    }
}

template <typename NumDims_G_M_N_K_O, // Sequence<>
          typename PerBlock_M_N_K_O,  // Sequence<>
          device::GemmSpecialization GemmSpec,
          device::TensorSpecialization ASpec,
          device::TensorSpecialization B0Spec,
          device::TensorSpecialization B1Spec,
          device::TensorSpecialization CSpec>
struct TransformBatchedContractionContractionToBatchedGemmGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0);
    static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1);
    static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2);
    static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3);
    static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4);

    static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0);
    static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1);
    static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2);
    static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3);

    static constexpr auto matrix_padder =
        device::GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock, OPerBlock};

    //
    // A
    //
    static auto MakeAGridDescriptorPair(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimK, ASpec>(a_gs_ms_ks_lengths_vec,
                                                                        a_gs_ms_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_KRaw
    static auto MakeAGridDescriptor_G_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                          const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first;
    }

    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        return matrix_padder.PadADescriptor_M_K(
            MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second);
    }

    template <typename AGridDesc_M_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);
        const auto K = a_grid_desc_m_k.GetLength(I1);

        const auto AK0 = K / AK1;

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // B (alias of B0)
    //
    static auto MakeB0GridDescriptorPair(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimN, NumDimK, B0Spec>(b0_gs_ns_ks_lengths_vec,
                                                                         b0_gs_ns_ks_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    static auto MakeB0GridDescriptor_G_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                           const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first;
    }

    static auto MakeB0GridDescriptor_N_K(const std::vector<index_t>& b0_gs_ns_ks_lengths_vec,
                                         const std::vector<index_t>& b0_gs_ns_ks_strides_vec)
    {
        // alias of matrix_padder.PadB0Descriptor_N_K
        return matrix_padder.PadBDescriptor_N_K(
            MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second);
    }

    template <typename BGridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1)
    {
        const auto N = b_grid_desc_n_k.GetLength(I0);
        const auto K = b_grid_desc_n_k.GetLength(I1);

        const auto BK0 = K / BK1;

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // B1
    //
    static auto MakeB1GridDescriptorPair(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimO, NumDimN, B1Spec>(b1_gs_os_ns_lengths_vec,
                                                                         b1_gs_os_ns_strides_vec);
    }

    // TODO: rename to G_NRaw_KRaw
    static auto MakeB1GridDescriptor_G_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                           const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first;
    }

    static auto MakeB1GridDescriptor_N_K(const std::vector<index_t>& b1_gs_os_ns_lengths_vec,
                                         const std::vector<index_t>& b1_gs_os_ns_strides_vec)
    {
        // alias of matrix_padder.PadB1Descriptor_O_N
        return matrix_padder.PadB1Descriptor_N_K(
            MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second);
    }

    template <typename B1GridDesc_N_K, typename Number>
    __host__ __device__ static constexpr auto
    MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1)
    {
        const auto N = b1_grid_desc_n_k.GetLength(I0);
        const auto K = b1_grid_desc_n_k.GetLength(I1);

        const auto B1K0 = K / B1K1;

        return transform_tensor_descriptor(
            b1_grid_desc_n_k,
            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    //
    // C
    //
    static auto MakeCGridDescriptorPair(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimO, CSpec>(c_gs_ms_os_lengths_vec,
                                                                        c_gs_ms_os_strides_vec);
    }

    // TODO: rename to G_MRaw_NRaw
    static auto MakeCGridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                          const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first;
    }

    static auto MakeCGridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_os_lengths_vec,
                                        const std::vector<index_t>& c_gs_ms_os_strides_vec)
    {
        return matrix_padder.PadCDescriptor_M_N(
            MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
    }
};

} // namespace tensor_operation
} // namespace ck
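
MakeGridDescriptorPair collapses the G-, M-, and N-dimension groups of a multi-index contraction tensor into a plain batched-GEMM view. Below is a minimal host-side sketch of what the Packed branch computes (plain C++, no CK types; the layout and all names are illustrative assumptions): merged lengths are the products of each group, and the descriptor reuses the stride of the last (innermost) dimension of each group.

    #include <cstdio>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main()
    {
        const int NumDimG = 2, NumDimM = 2, NumDimN = 1;
        // lengths for [G0, G1, M0, M1, N0] of a packed row-major tensor
        std::vector<long> len = {2, 3, 4, 5, 6};
        std::vector<long> stride(len.size());
        long s = 1; // packed strides, innermost dimension last
        for(int i = int(len.size()) - 1; i >= 0; --i) { stride[i] = s; s *= len[i]; }

        auto prod = [&](int b, int e) {
            return std::accumulate(len.begin() + b, len.begin() + e, 1L,
                                   std::multiplies<long>());
        };
        long G = prod(0, NumDimG);
        long M = prod(NumDimG, NumDimG + NumDimM);
        long N = prod(NumDimG + NumDimM, NumDimG + NumDimM + NumDimN);

        // the Packed descriptor uses the stride of the *last* dim in each group
        std::printf("G=%ld (stride %ld)  M=%ld (stride %ld)  N=%ld (stride %ld)\n",
                    G, stride[NumDimG - 1], stride[NumDimG + NumDimM - 1],
                    stride[NumDimG + NumDimM + NumDimN - 1]);
    }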
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp

@@ -9,46 +9,61 @@
 #include <algorithm>
 #include <thread>
 
-#include "ck/utility/math_v2.hpp"
+#include "ck/utility/ignore.hpp"
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
 
 namespace ck {
 namespace tensor_operation {
 namespace host {
 
-template <typename InOutDataType, typename AccDataType>
-struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatchNormFwd<4, 3>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp>
+struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
+    : public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
 {
     struct Argument : public device::BaseArgument
     {
         Argument(const std::array<index_t, 4> xyLengths,
                  const std::array<index_t, 4> xStrides,
                  const std::array<index_t, 4> yStrides,
+                 const std::array<int, 3> reduceDims,
                  const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
-                 const InOutDataType* p_x,
-                 const AccDataType* bnScale,
-                 const AccDataType* bnBias,
-                 InOutDataType* p_y,
-                 double exponentialAverageFactor,
-                 AccDataType* resultRunningMean,
-                 AccDataType* resultRunningVariance,
-                 double epsilon,
-                 AccDataType* resultSaveMean,
-                 AccDataType* resultSaveInvVariance)
+                 const std::array<index_t, 1> bnScaleStrides,
+                 const std::array<index_t, 1> bnBiasStrides,
+                 const std::array<index_t, 1> bnMeanVarStrides,
+                 const XDataType* p_x,
+                 const ScaleDataType* bnScale,
+                 const BiasDataType* bnBias,
+                 double epsilon,
+                 const YElementwiseOp y_elementwise_op,
+                 YDataType* p_y,
+                 MeanVarDataType* resultSaveMean,
+                 MeanVarDataType* resultSaveInvVariance,
+                 double averageFactor,
+                 MeanVarDataType* resultRunningMean,
+                 MeanVarDataType* resultRunningVariance)
             : p_x_(p_x),
               bnScale_(bnScale),
               bnBias_(bnBias),
+              y_elementwise_op_(y_elementwise_op),
              p_y_(p_y),
-              resultRunningMean_(resultRunningMean),
-              resultRunningVariance_(resultRunningVariance),
               resultSaveMean_(resultSaveMean),
               resultSaveInvVariance_(resultSaveInvVariance),
-              exponentialAverageFactor_(exponentialAverageFactor),
-              epsilon_(epsilon)
+              resultRunningMean_(resultRunningMean),
+              resultRunningVariance_(resultRunningVariance)
         {
-            (void)xStrides;
-            (void)yStrides;
-            (void)bnScaleBiasMeanVarStrides;
+            ignore = xStrides;
+            ignore = yStrides;
+            ignore = bnScaleStrides;
+            ignore = bnBiasStrides;
+            ignore = bnMeanVarStrides;
+            ignore = reduceDims;
 
             if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
                bnScaleBiasMeanVarLengths[0] != xyLengths[3])

@@ -59,26 +74,30 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
             w = xyLengths[2];
             c = xyLengths[3];
 
+            epsilon_       = type_convert<AccDataType>(epsilon);
+            averageFactor_ = type_convert<AccDataType>(averageFactor);
+
             resultSave    = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr);
             resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
         }
 
-        const InOutDataType* p_x_;
-        const AccDataType* bnScale_;
-        const AccDataType* bnBias_;
-        InOutDataType* p_y_;
+        const XDataType* p_x_;
+        const ScaleDataType* bnScale_;
+        const BiasDataType* bnBias_;
+        const YElementwiseOp y_elementwise_op_;
+        YDataType* p_y_;
 
-        AccDataType* resultRunningMean_;
-        AccDataType* resultRunningVariance_;
-        AccDataType* resultSaveMean_;
-        AccDataType* resultSaveInvVariance_;
+        MeanVarDataType* resultSaveMean_;
+        MeanVarDataType* resultSaveInvVariance_;
+        MeanVarDataType* resultRunningMean_;
+        MeanVarDataType* resultRunningVariance_;
 
         bool resultSave, resultRunning;
 
         index_t n, h, w, c;
 
-        double exponentialAverageFactor_;
-        double epsilon_;
+        AccDataType averageFactor_;
+        AccDataType epsilon_;
     };
 
     struct Invoker : public device::BaseInvoker

@@ -86,14 +105,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
         float Run(const Argument& arg)
         {
             auto thread_reduce_func = [&](auto iC) {
-                AccDataType reduceSize = type_convert<AccDataType>(arg.n) *
-                                         type_convert<AccDataType>(arg.h) *
-                                         type_convert<AccDataType>(arg.w);
                 index_t offset_C = iC;
 
                 AccDataType mean = type_convert<AccDataType>(0.0f);
-                AccDataType meansquare = type_convert<AccDataType>(0.0f);
+                AccDataType variance = type_convert<AccDataType>(0.0f);
+                int32_t curr_count = 0;
 
-                // compute mean, meanquare, variance, invVariance
+                // compute mean, variance using welford method
                 for(index_t iN = 0; iN < arg.n; iN++)
                 {
                     index_t offset_N = iN * arg.h * arg.w * arg.c;

@@ -106,40 +123,46 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
                             auto offset = offset_N + offset_H + offset_W + offset_C;
 
+                            curr_count++;
+
                             AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
 
-                            mean += x;
-                            meansquare += x * x;
+                            AccDataType delta = x - mean;
+
+                            mean += delta / curr_count;
+
+                            AccDataType delta2 = x - mean;
+
+                            variance += delta * delta2;
                         };
                     }
                 };
 
-                mean       = mean / reduceSize;
-                meansquare = meansquare / reduceSize;
-
-                AccDataType variance = meansquare - mean * mean;
+                // actual variance
+                variance = variance / curr_count;
 
                 AccDataType invVariance =
-                    type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
+                    type_convert<AccDataType>(1.0f) /
+                    std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
 
                 // save the mean/invVariance if required
                 if(arg.resultSave)
                 {
-                    arg.resultSaveMean_[iC]        = mean;
-                    arg.resultSaveInvVariance_[iC] = invVariance;
+                    arg.resultSaveMean_[iC]        = type_convert<MeanVarDataType>(mean);
+                    arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
                 };
 
                 // update the moving average if required
                 if(arg.resultRunning)
                 {
-                    arg.resultRunningMean_[iC] =
-                        arg.resultRunningMean_[iC] *
-                            type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) +
-                        mean * arg.exponentialAverageFactor_;
-                    arg.resultRunningVariance_[iC] =
-                        arg.resultRunningVariance_[iC] *
-                            type_convert<AccDataType>(1.0 - arg.exponentialAverageFactor_) +
-                        variance * arg.exponentialAverageFactor_;
+                    AccDataType oneMinusAverageFactor =
+                        type_convert<AccDataType>(1.0) - arg.averageFactor_;
+
+                    arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
+                        type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
+                            oneMinusAverageFactor +
+                        mean * arg.averageFactor_);
+                    arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
+                        arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
+                        variance * arg.averageFactor_);
                 };
 
                 // Normalization

@@ -160,7 +183,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
                         AccDataType norm_x =
                             arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
 
-                        arg.p_y_[offset] = type_convert<InOutDataType>(norm_x);
+                        arg.p_y_[offset] = type_convert<YDataType>(norm_x);
                     };
                 }
             };

@@ -207,34 +230,42 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C : public device::DeviceBatch
     MakeArgumentPointer(const std::array<index_t, 4> xyLengths,
                         const std::array<index_t, 4> xStrides,
                         const std::array<index_t, 4> yStrides,
+                        const std::array<int, 3> reduceDims,
                         const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                        const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
+                        const std::array<index_t, 1> bnScaleStrides,
+                        const std::array<index_t, 1> bnBiasStrides,
+                        const std::array<index_t, 1> bnMeanVarStrides,
                         const void* p_x,
                         const void* bnScale,
                         const void* bnBias,
-                        void* p_y,
-                        double exponentialAverageFactor,
-                        void* resultRunningMean,
-                        void* resultRunningVariance,
                         double epsilon,
+                        const YElementwiseOp y_elementwise_op,
+                        void* p_y,
                         void* resultSaveMean,
-                        void* resultSaveInvVariance) override
+                        void* resultSaveInvVariance,
+                        double averageFactor,
+                        void* resultRunningMean,
+                        void* resultRunningVariance) override
     {
         return std::make_unique<Argument>(xyLengths,
                                           xStrides,
                                           yStrides,
+                                          reduceDims,
                                           bnScaleBiasMeanVarLengths,
-                                          bnScaleBiasMeanVarStrides,
-                                          static_cast<const InOutDataType*>(p_x),
-                                          static_cast<const AccDataType*>(bnScale),
-                                          static_cast<const AccDataType*>(bnBias),
-                                          static_cast<InOutDataType*>(p_y),
-                                          exponentialAverageFactor,
-                                          static_cast<AccDataType*>(resultRunningMean),
-                                          static_cast<AccDataType*>(resultRunningVariance),
-                                          epsilon,
-                                          static_cast<AccDataType*>(resultSaveMean),
-                                          static_cast<AccDataType*>(resultSaveInvVariance));
+                                          bnScaleStrides,
+                                          bnBiasStrides,
+                                          bnMeanVarStrides,
+                                          static_cast<const XDataType*>(p_x),
+                                          static_cast<const ScaleDataType*>(bnScale),
+                                          static_cast<const BiasDataType*>(bnBias),
+                                          epsilon,
+                                          y_elementwise_op,
+                                          static_cast<YDataType*>(p_y),
+                                          static_cast<MeanVarDataType*>(resultSaveMean),
+                                          static_cast<MeanVarDataType*>(resultSaveInvVariance),
+                                          averageFactor,
+                                          static_cast<MeanVarDataType*>(resultRunningMean),
+                                          static_cast<MeanVarDataType*>(resultRunningVariance));
     };
 
     std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
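
The new Argument stores an AccDataType averageFactor_ and updates the running statistics per channel as running = running * (1 - f) + batch * f. A tiny standalone sketch of that update with concrete numbers (plain C++, hypothetical values, no CK types):

    #include <cstdio>

    int main()
    {
        double runningMean = 0.50, runningVar = 1.00; // previous moving averages
        double batchMean = 0.80, batchVar = 0.90;     // stats of the current batch
        double averageFactor = 0.1;

        // same blend as resultRunningMean_/resultRunningVariance_ above
        double oneMinus = 1.0 - averageFactor;
        runningMean = runningMean * oneMinus + batchMean * averageFactor;
        runningVar  = runningVar * oneMinus + batchVar * averageFactor;

        std::printf("running mean=%f var=%f\n", runningMean, runningVar); // 0.53, 0.99
    }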
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp

@@ -14,7 +14,12 @@ namespace ck {
 namespace tensor_operation {
 namespace host {
 
-template <typename InOutDataType, typename AccDataType>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType>
 struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBatchNormInfer<4, 3>
 {
     struct Argument : public device::BaseArgument

@@ -23,14 +28,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                  const std::array<index_t, 4> xStrides,
                  const std::array<index_t, 4> yStrides,
                  const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
-                 const InOutDataType* p_x,
-                 const AccDataType* bnScale,
-                 const AccDataType* bnBias,
+                 const std::array<index_t, 1> bnScaleStrides,
+                 const std::array<index_t, 1> bnBiasStrides,
+                 const std::array<index_t, 1> bnMeanVarStrides,
+                 const XDataType* p_x,
+                 const ScaleDataType* bnScale,
+                 const BiasDataType* bnBias,
                  double epsilon,
-                 const AccDataType* estimatedMean,
-                 const AccDataType* estimatedVariance,
-                 InOutDataType* p_y)
+                 const MeanVarDataType* estimatedMean,
+                 const MeanVarDataType* estimatedVariance,
+                 YDataType* p_y)
             : p_x_(p_x),
               bnScale_(bnScale),
               bnBias_(bnBias),

@@ -39,32 +46,34 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
               estimatedVariance_(estimatedVariance),
               p_y_(p_y)
         {
-            (void)xStrides;
-            (void)yStrides;
-            (void)bnScaleBiasMeanVarStrides;
+            ignore = xStrides;
+            ignore = yStrides;
+            ignore = bnScaleStrides;
+            ignore = bnBiasStrides;
+            ignore = bnMeanVarStrides;
 
             if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
                bnScaleBiasMeanVarLengths[0] != xyLengths[3])
                 throw std::runtime_error("Invalid tensor dimensions!");
 
-            n = xyLengths[0];
-            h = xyLengths[1];
-            w = xyLengths[2];
-            c = xyLengths[3];
+            n_ = xyLengths[0];
+            h_ = xyLengths[1];
+            w_ = xyLengths[2];
+            c_ = xyLengths[3];
         }
 
-        const InOutDataType* p_x_;
-        const AccDataType* bnScale_;
-        const AccDataType* bnBias_;
+        const XDataType* p_x_;
+        const ScaleDataType* bnScale_;
+        const BiasDataType* bnBias_;
 
         double epsilon_;
 
-        const AccDataType* estimatedMean_;
-        const AccDataType* estimatedVariance_;
+        const MeanVarDataType* estimatedMean_;
+        const MeanVarDataType* estimatedVariance_;
 
-        InOutDataType* p_y_;
+        YDataType* p_y_;
 
-        index_t n, h, w, c;
+        index_t n_, h_, w_, c_;
     };
 
     struct Invoker : public device::BaseInvoker

@@ -81,15 +90,15 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                     std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);
 
                 // Normalization
-                for(index_t iN = 0; iN < arg.n; iN++)
+                for(index_t iN = 0; iN < arg.n_; iN++)
                 {
-                    index_t offset_N = iN * arg.h * arg.w * arg.c;
-                    for(index_t iH = 0; iH < arg.h; iH++)
+                    index_t offset_N = iN * arg.h_ * arg.w_ * arg.c_;
+                    for(index_t iH = 0; iH < arg.h_; iH++)
                     {
-                        index_t offset_H = iH * arg.w * arg.c;
-                        for(index_t iW = 0; iW < arg.w; iW++)
+                        index_t offset_H = iH * arg.w_ * arg.c_;
+                        for(index_t iW = 0; iW < arg.w_; iW++)
                         {
-                            index_t offset_W = iW * arg.c;
+                            index_t offset_W = iW * arg.c_;
 
                             auto offset = offset_N + offset_H + offset_W + offset_C;

@@ -98,21 +107,21 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                             AccDataType norm_x =
                                 arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
 
-                            arg.p_y_[offset] = type_convert<InOutDataType>(norm_x);
+                            arg.p_y_[offset] = type_convert<YDataType>(norm_x);
                         };
                     }
                 };
             };
 
             std::size_t num_thread      = std::thread::hardware_concurrency();
-            std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread;
+            std::size_t work_per_thread = (arg.c_ + num_thread - 1) / num_thread;
 
             std::vector<joinable_thread> threads(num_thread);
 
             for(std::size_t it = 0; it < num_thread; ++it)
             {
                 std::size_t ic_begin = it * work_per_thread;
-                std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c);
+                std::size_t ic_end = std::min(static_cast<int>((it + 1) * work_per_thread), arg.c_);
 
                 auto f = [=] {
                     for(std::size_t ic = ic_begin; ic < ic_end; ++ic)

@@ -146,7 +155,9 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                         const std::array<index_t, 4> xStrides,
                         const std::array<index_t, 4> yStrides,
                         const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                        const std::array<index_t, 1> bnScaleBiasMeanVarStrides,
+                        const std::array<index_t, 1> bnScaleStrides,
+                        const std::array<index_t, 1> bnBiasStrides,
+                        const std::array<index_t, 1> bnMeanVarStrides,
                         const void* p_x,
                         const void* bnScale,
                         const void* bnBias,

@@ -159,14 +170,16 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                                           xStrides,
                                           yStrides,
                                           bnScaleBiasMeanVarLengths,
-                                          bnScaleBiasMeanVarStrides,
-                                          static_cast<const InOutDataType*>(p_x),
-                                          static_cast<const AccDataType*>(bnScale),
-                                          static_cast<const AccDataType*>(bnBias),
+                                          bnScaleStrides,
+                                          bnBiasStrides,
+                                          bnMeanVarStrides,
+                                          static_cast<const XDataType*>(p_x),
+                                          static_cast<const ScaleDataType*>(bnScale),
+                                          static_cast<const BiasDataType*>(bnBias),
                                           epsilon,
-                                          static_cast<const AccDataType*>(estimatedMean),
-                                          static_cast<const AccDataType*>(estimatedVariance),
-                                          static_cast<InOutDataType*>(p_y));
+                                          static_cast<const MeanVarDataType*>(estimatedMean),
+                                          static_cast<const MeanVarDataType*>(estimatedVariance),
+                                          static_cast<YDataType*>(p_y));
     };
 
     std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp

@@ -131,17 +131,22 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
         else if constexpr(NDimSpatial == 2)
         {
             auto f_kcyx = [&](auto g, auto k, auto c, auto y, auto x) {
+                std::size_t N  = arg.output_.GetLengths()[1];
+                std::size_t Ho = arg.output_.GetLengths()[3];
+                std::size_t Wo = arg.output_.GetLengths()[4];
+
                 float v_acc = 0;
-                for(std::size_t n = 0; n < arg.output_.GetLengths()[1]; ++n)
+                for(std::size_t n = 0; n < N; ++n)
                 {
-                    for(std::size_t ho = 0; ho < arg.output_.GetLengths()[3]; ++ho)
+                    for(std::size_t ho = 0; ho < Ho; ++ho)
                     {
                         auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
                                   static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
                                   static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
-                        for(std::size_t wo = 0; wo < arg.output_.GetLengths()[4]; ++wo)
+                        for(std::size_t wo = 0; wo < Wo; ++wo)
                         {
                             auto wi =
                                 static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp

@@ -44,8 +44,8 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
         size_t M = acc.mDesc.GetLengths()[0];
         size_t N = acc.mDesc.GetLengths()[1];
 
-        Tensor<ComputeDataType> avg_acc_sq(HostTensorDescriptor(std::vector<size_t>({M})));
-        Tensor<ComputeDataType> avg_acc(HostTensorDescriptor(std::vector<size_t>({M})));
+        Tensor<ComputeDataType> avg_acc_sq({M});
+        Tensor<ComputeDataType> avg_acc({M});
         Tensor<ComputeDataType> acc_layernorm(acc);
 
         // reduce N dim
library/include/ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp

@@ -95,6 +95,7 @@ struct ReferenceLayernorm : public device::BaseOperator
                 auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n));
                 auto y_val = (x_val - mean(m)) / sqrt(var(m) + arg.epsilon_);
                 y_val      = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n);
+                arg.acc_elementwise_op_(y_val, y_val);
                 arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
             }
         }
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp

@@ -60,6 +60,12 @@ struct ReferenceSoftmax : public device::BaseOperator
             {
                 scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]);
             }
+            // max and sum reduction with final reduced values of dim=0 is a scalar so give it
+            // appropriate lengths of {1}
+            if(arg.sm_scalar_dims_.size() == 0)
+            {
+                scalar_lengths.push_back(1);
+            }
 
             Tensor<AccDataType> reduce_max(scalar_lengths);
             reduce_max.GenerateTensorValue(

@@ -67,6 +73,9 @@ struct ReferenceSoftmax : public device::BaseOperator
             Tensor<AccDataType> reduce_sum(scalar_lengths);
             reduce_sum.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0});
 
+            // when final reduced values is of dim=0, the index will be transformed into empty
+            // std::vector which is actually a valid input for Tensor::operator(std::vector) and
+            // internally accesses 0'th element
             auto to_sm_scalar_idx = [&](auto idx) {
                 std::vector<size_t> sm_scalar_idx;
                 for(index_t dim : arg.sm_scalar_dims_)
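
The added push_back(1) gives a fully reduced tensor the lengths {1} that the comment describes; an empty multi-index then resolves to element 0. A standalone sketch of that indexing rule (plain C++, not CK's Tensor; names are illustrative):

    #include <cassert>
    #include <numeric>
    #include <vector>

    // offset = sum(idx[d] * stride[d]); an empty index yields offset 0
    size_t offset(const std::vector<size_t>& idx, const std::vector<size_t>& stride)
    {
        return std::inner_product(idx.begin(), idx.end(), stride.begin(), size_t{0});
    }

    int main()
    {
        std::vector<float> scalar_buf(1, 0.f);  // lengths {1}
        std::vector<size_t> stride = {1};
        scalar_buf[offset({}, stride)] = 42.f;  // empty multi-index -> element 0
        assert(scalar_buf[0] == 42.f);
    }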
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -3,7 +3,10 @@
 #pragma once
 
-#include <cstdlib>
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/utility/data_type.hpp"
+#include "ck/utility/tuple.hpp"
 
 namespace ck {
 namespace tensor_operation {

@@ -15,6 +18,8 @@ using F64 = double;
 using F32  = float;
 using F16  = ck::half_t;
 using BF16 = ck::bhalf_t;
+using I8   = int8_t;
+using I32  = int32_t;
 
 using Empty_Tuple = ck::Tuple<>;

@@ -23,6 +28,8 @@ using F16_F16_Tuple = ck::Tuple<F16, F16>;
 using F32_Tuple = ck::Tuple<F32>;
+using I32_Tuple = ck::Tuple<I32>;
 
 // GEMM layout
 using Row = ck::tensor_layout::gemm::RowMajor;
 using Col = ck::tensor_layout::gemm::ColumnMajor;

@@ -70,13 +77,25 @@ using NWGK = ck::tensor_layout::convolution::NWGK;
 using NHWGK  = ck::tensor_layout::convolution::NHWGK;
 using NDHWGK = ck::tensor_layout::convolution::NDHWGK;
 
+//
+using GK       = ck::tensor_layout::convolution::G_K;
+using GK_TUPLE = ck::Tuple<GK>;
+
 // pointwise functor
 using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
+using Relu           = ck::tensor_operation::element_wise::Relu;
 using Scale          = ck::tensor_operation::element_wise::Scale;
 using Bilinear       = ck::tensor_operation::element_wise::Bilinear;
 using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
 
-template <typename DeviceOp>
+template <typename Activation>
+using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
+
+template <typename Activation>
+using Add_Activation_Mul_Clamp =
+    ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Activation>;
+
+template <typename DeviceOp, typename Tag = void>
 struct DeviceOperationInstanceFactory;
 
 } // namespace instance
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp

@@ -28,9 +28,26 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
                                                              F16,
                                                              F16,
                                                              PassThrough,
                                                              PassThrough,
                                                              Scale,
                                                              PassThrough,
-                                                             PassThrough>>>& instances);
+                                                             PassThrough,
+                                                             false>>>& instances);
+
+void add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
+    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
+                                                             Col,
+                                                             Row,
+                                                             Row,
+                                                             F16,
+                                                             F16,
+                                                             F16,
+                                                             F16,
+                                                             PassThrough,
+                                                             PassThrough,
+                                                             Scale,
+                                                             PassThrough,
+                                                             PassThrough,
+                                                             true>>>& instances);
 
 template <typename ALayout,
           typename B0Layout,

@@ -39,7 +56,8 @@ template <typename ALayout,
           typename ADataType,
           typename B0DataType,
           typename B1DataType,
-          typename CDataType>
+          typename CDataType,
+          bool MaskOutUpperTriangle>
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm<ALayout,
                                                                B0Layout,

@@ -51,9 +69,10 @@ struct DeviceOperationInstanceFactory<
                                                                CDataType,
                                                                PassThrough,
                                                                PassThrough,
                                                                Scale,
                                                                PassThrough,
-                                                               PassThrough>>
+                                                               PassThrough,
+                                                               MaskOutUpperTriangle>>
 {
     using DeviceOp = DeviceBatchedGemmSoftmaxGemm<ALayout,
                                                   B0Layout,

@@ -65,9 +84,10 @@ struct DeviceOperationInstanceFactory<
                                                   CDataType,
                                                   PassThrough,
                                                   PassThrough,
                                                   Scale,
                                                   PassThrough,
-                                                  PassThrough>;
+                                                  PassThrough,
+                                                  MaskOutUpperTriangle>;
 
     static auto GetInstances()
     {

@@ -78,11 +98,19 @@ struct DeviceOperationInstanceFactory<
         if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                      is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
         {
-            add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
-                op_ptrs);
+            if constexpr(MaskOutUpperTriangle)
+            {
+                add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
+                    op_ptrs);
+            }
+            else
+            {
+                add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
+                    op_ptrs);
+            }
         }
 
         return op_ptrs;
     }
 };
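
With the factory now keyed on MaskOutUpperTriangle, a caller selects masked or unmasked instances purely through a template argument. A hedged usage sketch follows (not part of this commit; the include path and template ordering are taken from the header above, and BaseOperator::GetTypeString() is assumed for printing — it only compiles against the CK library itself):

    #include <iostream>
    #include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp"

    int main()
    {
        namespace dev = ck::tensor_operation::device;
        using F16         = ck::half_t;
        using Row         = ck::tensor_layout::gemm::RowMajor;
        using Col         = ck::tensor_layout::gemm::ColumnMajor;
        using PassThrough = ck::tensor_operation::element_wise::PassThrough;
        using Scale       = ck::tensor_operation::element_wise::Scale;

        // trailing 'true' selects the MaskOutUpperTriangle instances added here
        using DeviceOp = dev::DeviceBatchedGemmSoftmaxGemm<Row, Col, Row, Row,
                                                           F16, F16, F16, F16,
                                                           PassThrough, PassThrough, Scale,
                                                           PassThrough, PassThrough,
                                                           true>;

        auto op_ptrs = dev::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
        for(const auto& op : op_ptrs)
            std::cout << op->GetTypeString() << '\n'; // assumed BaseOperator accessor
    }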
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp → library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp (renamed)

@@ -17,63 +17,89 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-
-using CPermuteNumDims_G_M_O =
-    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
-
-void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
-    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
-                                                                    Col,
-                                                                    Row,
-                                                                    CPermuteNumDims_G_M_O,
-                                                                    F16,
-                                                                    F16,
-                                                                    F16,
-                                                                    F16,
-                                                                    PassThrough,
-                                                                    PassThrough,
-                                                                    Scale,
-                                                                    PassThrough,
-                                                                    PassThrough>>>& instances);
+void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
+                                            F16, F16, F16, F16,
+                                            ck::Tuple<>, ck::Tuple<>,
+                                            PassThrough, PassThrough, Scale,
+                                            PassThrough, PassThrough,
+                                            MaskingSpecialization::MaskOutUpperTriangle>>>&
+        instances);
+
+void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
+                                            F16, F16, F16, F16,
+                                            ck::Tuple<>, ck::Tuple<>,
+                                            PassThrough, PassThrough, Scale,
+                                            PassThrough, PassThrough,
+                                            MaskingSpecialization::MaskDisabled>>>&
        instances);
 
-template <typename ALayout,
-          typename B0Layout,
-          typename B1Layout,
-          typename CPermuteNumDims_G_M_Gemm1N,
-          typename ADataType,
+template <typename ADataType,
           typename B0DataType,
           typename B1DataType,
-          typename CDataType>
+          typename CDataType,
+          MaskingSpecialization MaskingSpec>
 struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-                                                                      B0Layout,
-                                                                      B1Layout,
-                                                                      CPermuteNumDims_G_M_Gemm1N,
-                                                                      ADataType,
-                                                                      B0DataType,
-                                                                      B1DataType,
-                                                                      CDataType,
-                                                                      PassThrough,
-                                                                      PassThrough,
-                                                                      Scale,
-                                                                      PassThrough,
-                                                                      PassThrough>>
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
+                                                                      1,
+                                                                      1,
+                                                                      1,
+                                                                      1,
+                                                                      ADataType,
+                                                                      B0DataType,
+                                                                      B1DataType,
+                                                                      CDataType,
+                                                                      ck::Tuple<>,
+                                                                      ck::Tuple<>,
+                                                                      PassThrough,
+                                                                      PassThrough,
+                                                                      Scale,
+                                                                      PassThrough,
+                                                                      PassThrough,
+                                                                      MaskingSpec>>
 {
-    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-                                                         B0Layout,
-                                                         B1Layout,
-                                                         CPermuteNumDims_G_M_Gemm1N,
-                                                         ADataType,
-                                                         B0DataType,
-                                                         B1DataType,
-                                                         CDataType,
-                                                         PassThrough,
-                                                         PassThrough,
-                                                         Scale,
-                                                         PassThrough,
-                                                         PassThrough>;
+    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
+                                                         ADataType,
+                                                         B0DataType,
+                                                         B1DataType,
+                                                         CDataType,
+                                                         ck::Tuple<>,
+                                                         ck::Tuple<>,
+                                                         PassThrough,
+                                                         PassThrough,
+                                                         Scale,
+                                                         PassThrough,
+                                                         PassThrough,
+                                                         MaskingSpec>;
 
     static auto GetInstances()
     {

@@ -82,11 +108,14 @@ struct DeviceOperationInstanceFactory<
         if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                      is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
         {
-            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
-                         is_same_v<B1Layout, Row> &&
-                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
+            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
             {
-                add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
+                add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
+                    op_ptrs);
+            }
+            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
+            {
+                add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
                     op_ptrs);
             }
         }
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_weight.hpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<1, NWC, KXC, NWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

// conv2d backward weight
void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<2, NHWC, KYXC, NHWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

// conv3d backward weight
void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, BF16, F32, BF16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F16, F16, F16,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

void add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdWeight<3, NDHWC, KZYXC, NDHWK, F32, F32, F32,
                                                    PassThrough, PassThrough, PassThrough>>>&
        instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceConvBwdWeight<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceConvBwdWeight<NumDimSpatial, InLayout, WeiLayout, OutLayout,
                                         InDataType, WeiDataType, OutDataType,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough,
                                         ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, NWC> &&
                     is_same_v<WeiLayout, KXC> && is_same_v<OutLayout, NWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv1d_bwd_weight_xdl_nwc_kxc_nwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
                          is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
                          is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_conv3d_bwd_weight_xdl_ndhwc_kzyxc_ndhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
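With this commit the non-grouped backward-weight factory header is deleted; the grouped header added later in this commit covers the same data-type combinations under G-prefixed layouts. For reference, every factory in this family is consumed the same way. A hedged sketch of that pattern (GetTypeString() is part of CK's base-operator interface; the rest is generic C++):

    #include <iostream>

    // List what a factory specialization registered for a given DeviceOp interface.
    // A minimal sketch: real callers go on to build arguments with the operation's
    // MakeArgumentPointer(), filter with IsSupportedArgument(), and run the best one.
    template <typename DeviceOp>
    void list_instances()
    {
        auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        std::cout << "found " << op_ptrs.size() << " instances\n";
        for(const auto& op_ptr : op_ptrs)
            std::cout << "  " << op_ptr->GetTypeString() << '\n';
    }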
library/include/ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp

...
@@ -7,7 +7,7 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
...
library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// FP16
void add_device_elementwise_normalization_rank_2_1_f16_instances(
    std::vector<std::unique_ptr<DeviceElementwiseNormalization<ck::Tuple<F16, F16>,
                                                               F16, F16, F32, F16,
                                                               element_wise::Add,
                                                               PassThrough, 2, 1>>>&);

template <typename InDataTypeTuple,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceElementwiseNormalization<
        InDataTypeTuple, GammaDataType, BetaDataType, F32, YDataType,
        ck::tensor_operation::element_wise::Add,
        ck::tensor_operation::element_wise::PassThrough, Rank, NumReduceDim>>
{
    using DeviceOp = DeviceElementwiseNormalization<
        InDataTypeTuple, GammaDataType, BetaDataType, F32, YDataType,
        ck::tensor_operation::element_wise::Add,
        ck::tensor_operation::element_wise::PassThrough, Rank, NumReduceDim>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<GammaDataType, F16> && is_same_v<BetaDataType, F16> &&
                     is_same_v<YDataType, F16>)
        {
            if constexpr(Rank == 2 && NumReduceDim == 1)
            {
                add_device_elementwise_normalization_rank_2_1_f16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
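The element_wise::Add combiner means this factory currently serves the fused "x0 + x1, then normalize" case (e.g. residual-add followed by layernorm), with f32 accumulation fixed by the specialization. A hedged instantiation sketch matching the only branch registered above (F16/F32 aliases are defined locally for the sketch):

    #include "ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp"

    using F16 = ck::half_t;
    using F32 = float;

    // Two f16 inputs summed element-wise, normalized over the last of 2 dims,
    // with f16 gamma/beta/output and f32 accumulation -- the rank_2_1 branch above.
    using DeviceOp = ck::tensor_operation::device::DeviceElementwiseNormalization<
        ck::Tuple<F16, F16>, F16, F16, F32, F16,
        ck::tensor_operation::element_wise::Add,
        ck::tensor_operation::element_wise::PassThrough,
        2, 1>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();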
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv2d backward data
void add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedConvBwdDataMultipleD<2, GNHWK, GKYXC, Empty_Tuple, GNHWC,
                                          F16, F16, Empty_Tuple, F16,
                                          PassThrough, PassThrough, PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename OutLayout,
          typename WeiLayout,
          typename InLayout,
          typename OutDataType,
          typename WeiDataType,
          typename InDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
        NumDimSpatial, OutLayout, WeiLayout, Empty_Tuple, InLayout,
        OutDataType, WeiDataType, Empty_Tuple, InDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvBwdDataMultipleD<
        NumDimSpatial, OutLayout, WeiLayout, Empty_Tuple, InLayout,
        OutDataType, WeiDataType, Empty_Tuple, InDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
                         is_same_v<OutDataType, F16>)
            {
                add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
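Note the template-parameter ordering here: the factory takes OutLayout/WeiLayout before InLayout because backward-data consumes the output gradient and the weights and produces the input gradient. A hedged instantiation sketch for the single registered branch (ck::Tuple<> is assumed to be what Empty_Tuple aliases, and the layout tags are assumed to live in ck::tensor_layout::convolution):

    #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp"

    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    namespace ctl     = ck::tensor_layout::convolution; // assumed home of the layout tags

    // Output gradient (GNHWK) and weights (GKYXC) in; input gradient (GNHWC) out.
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<
        2, ctl::GNHWK, ctl::GKYXC, ck::Tuple<>, ctl::GNHWC,
        F16, F16, ck::Tuple<>, F16,
        PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();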
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// conv1d backward weight
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK,
                                                           BF16, F32, BF16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK,
                                                           F16, F16, F16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1, GNWC, GKXC, GNWK,
                                                           F32, F32, F32,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

// conv2d backward weight
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK,
                                                           BF16, F32, BF16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK,
                                                           F16, F16, F16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2, GNHWC, GKYXC, GNHWK,
                                                           F32, F32, F32,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

// conv3d backward weight
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK,
                                                           BF16, F32, BF16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK,
                                                           F16, F16, F16,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3, GNDHWC, GKZYXC, GNDHWK,
                                                           F32, F32, F32,
                                                           PassThrough, PassThrough,
                                                           PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvBwdWeight<NumDimSpatial, InLayout, WeiLayout, OutLayout,
                                                InDataType, WeiDataType, OutDataType,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
                     is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                          is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
                          is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                              is_same_v<WeiDataType, float> &&
                              is_same_v<OutDataType, ck::bhalf_t>)
            {
                add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
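One detail worth noting in the branches above: for bf16 the factory requires WeiDataType == float, i.e. the weight gradient is produced in f32 while the activation tensors stay bf16, presumably because bf16's short mantissa makes it a poor accumulation target for weight gradients. A hedged instantiation sketch of that mixed-precision branch (layout-tag namespace assumed, as before):

    #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    namespace ctl     = ck::tensor_layout::convolution; // assumed home of the layout tags

    // bf16 input/output gradient, f32 weight gradient: the *_bf16_f32_bf16 branch above.
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<
        2, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK,
        ck::bhalf_t, float, ck::bhalf_t,
        PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();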
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_bias_forward_perlayer_quantization.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_conv2d_bias_perlayer_quantization_int8_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedConvFwdMultipleD<2, GNHWC, GKYXC, GK_TUPLE, GNHWK,
                                      int8_t, int8_t, I32_Tuple, int8_t,
                                      PassThrough, PassThrough,
                                      Add_Activation_Mul_Clamp<PassThrough>>>>& instances);

void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(
    std::vector<std::unique_ptr<
        DeviceGroupedConvFwdMultipleD<2, GNHWC, GKYXC, GK_TUPLE, GNHWK,
                                      int8_t, int8_t, I32_Tuple, int8_t,
                                      PassThrough, PassThrough,
                                      Add_Activation_Mul_Clamp<Relu>>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename DsLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename DsDataType,
          typename OutDataType,
          typename Activation>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
        NumDimSpatial, InLayout, WeiLayout, DsLayout, OutLayout,
        InDataType, WeiDataType, DsDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        Add_Activation_Mul_Clamp<Activation>>>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleD<
        NumDimSpatial, InLayout, WeiLayout, DsLayout, OutLayout,
        InDataType, WeiDataType, DsDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        Add_Activation_Mul_Clamp<Activation>>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                     is_same_v<WeiLayout, GKYXC> && is_same_v<DsLayout, GK_TUPLE> &&
                     is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                         is_same_v<DsDataType, I32_Tuple> && is_same_v<OutDataType, int8_t>)
            {
                if constexpr(is_same_v<Activation, PassThrough>)
                    add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs);
                else if constexpr(is_same_v<Activation, Relu>)
                    add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
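Reading the epilogue name Add_Activation_Mul_Clamp left to right suggests the per-element post-processing applied to the i32 accumulator. The scalar reference below is an interpretation of that name, not the library's code; requant_scale is a hypothetical per-layer constant (hence "perlayer quantization" in the file name):

    #include <algorithm>
    #include <cstdint>

    // Hedged scalar reading of Add_Activation_Mul_Clamp<Relu> for int8 output:
    // add the i32 bias, apply the activation, multiply by the per-layer
    // re-quantization scale, then clamp into the int8 range.
    inline std::int8_t requantize(std::int32_t acc, std::int32_t bias, float requant_scale)
    {
        float y = static_cast<float>(acc + bias);         // Add: i32 bias from the Ds tuple
        y       = std::max(y, 0.0f);                      // Activation: Relu (PassThrough skips this)
        y       = y * requant_scale;                      // Mul: hypothetical per-layer scale
        y       = std::min(std::max(y, -128.0f), 127.0f); // Clamp: int8 range
        return static_cast<std::int8_t>(y);
    }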
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp

...
@@ -3,11 +3,11 @@
 #pragma once

-#include <cstdlib>
+#include <vector>

 #include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp
0 → 100644

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, GNHWC/GKYXC/GNHWK
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK,
                                                     F16, F16, F16,
                                                     PassThrough, PassThrough,
                                                     PassThrough>>>& instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK,
                                                     F32, F32, F32,
                                                     PassThrough, PassThrough,
                                                     PassThrough>>>& instances);

void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvFwd<2, GNHWC, GKYXC, GNHWK,
                                                     int8_t, int8_t, int8_t,
                                                     PassThrough, PassThrough,
                                                     PassThrough>>>& instances);

template <ck::index_t NumDimSpatial,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGroupedConvFwd<
        NumDimSpatial, InLayout, WeiLayout, OutLayout, InDataType, WeiDataType, OutDataType,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGroupedConvFwd<NumDimSpatial, InLayout, WeiLayout, OutLayout,
                                          InDataType, WeiDataType, OutDataType,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough,
                                          ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
                     is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
        {
            if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
                         is_same_v<OutDataType, float>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                              is_same_v<OutDataType, half_t>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
            }
            else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
                              is_same_v<OutDataType, int8_t>)
            {
                add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
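The _dl suffix marks instances built on the "dl" kernel pipeline rather than the xdlops one, which appears to be what makes grouped forward convolution available on GPUs without matrix-core (XDL) instructions; the factory interface itself is identical to the xdl headers above. A hedged instantiation sketch for the int8 branch (layout-tag namespace assumed, as before):

    #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.hpp"

    #include <cstdint>

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    namespace ctl     = ck::tensor_layout::convolution; // assumed home of the layout tags

    // int8 2-D grouped forward convolution on the dl (non-XDL) pipeline.
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwd<
        2, ctl::GNHWC, ctl::GKYXC, ctl::GNHWK,
        std::int8_t, std::int8_t, std::int8_t,
        PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();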