gaoqiong / composable_kernel / Commits

Commit d85b89bb, authored Aug 02, 2019 by Tejash Shah
Parent: 2185affb

    Used half4 datatype for blockwise gemm in place of half datatype

Showing 9 changed files with 202 additions and 104 deletions:
  +2   -2    composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp
  +42  -16   composable_kernel/include/tensor_operation/blockwise_gemm.hpp
  +38  -55   composable_kernel/include/tensor_operation/threadwise_gemm.hpp
  +0   -1    composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
  +13  -10   composable_kernel/include/utility/amd_inline_asm.hpp
  +13  -0    composable_kernel/include/utility/config_amd.hpp.in
  +13  -0    composable_kernel/include/utility/config_nvidia.hpp.in
  +64  -19   composable_kernel/include/utility/vector_type.hpp
  +17  -1    driver/src/driver.cpp
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_fp16_bfp16_nchw_kcyx_nkhw_lds_double_buffer.hpp
composable_kernel/include/tensor_operation/blockwise_gemm.hpp

@@ -6,7 +6,7 @@
 #include "threadwise_gemm.hpp"

 #ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
-#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
+#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 0
 #endif

 namespace ck {

@@ -136,9 +136,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr auto b_thread_mtx =
             make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
         constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

@@ -153,6 +150,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                   std::is_same<FloatB, float>::value>{}([&](auto) {
             using Float4 = vector_type<float, 4>::MemoryType;

+            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
             Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
             Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
             Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);

@@ -183,33 +183,39 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
         }).Else([&](auto) {
             // If A and B datatype is bfloat16/float16
-            using Half4x4 = vector_type<vector_type<half, 4>, 4>;
+            FloatA p_a_thread[a_thread_mtx.GetElementSpace() * 4];
+            FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 4];
+
+            using Half4x4 = vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType;
             using Float4  = vector_type<float, 4>::MemoryType;

             Half4x4* reg_a = reinterpret_cast<Half4x4*>(p_a_thread);
             Half4x4* reg_b = reinterpret_cast<Half4x4*>(p_b_thread);
             Float4*  reg_c = reinterpret_cast<Float4*>(p_c_thread);

-            reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA]);
-            reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB]);
-            reg_b[1] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
-            reg_a[1] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
+            reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA * 4]);
+            reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB * 4]);
+            reg_b[1] = *reinterpret_cast<const Half4x4*>(&p_b_block[(mMyThreadOffsetB + NPerLevel1Cluster) * 4]);
+            reg_a[1] = *reinterpret_cast<const Half4x4*>(&p_a_block[(mMyThreadOffsetA + MPerLevel1Cluster) * 4]);

             outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
             outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

 #pragma unroll
             for(index_t k = 1; k < K; ++k)
             {
-                reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + k * M]);
+                reg_a[0] = *reinterpret_cast<const Half4x4*>(&p_a_block[(mMyThreadOffsetA + k * M) * 4]);
                 outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
-                reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + k * N]);
+                reg_b[0] = *reinterpret_cast<const Half4x4*>(&p_b_block[(mMyThreadOffsetB + k * N) * 4]);
                 outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
-                reg_b[1] = *reinterpret_cast<const Half4x4*>(&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
+                reg_b[1] = *reinterpret_cast<const Half4x4*>(&p_b_block[(mMyThreadOffsetB + k * N + NPerLevel1Cluster) * 4]);
-                reg_a[1] = *reinterpret_cast<const Half4x4*>(&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
+                reg_a[1] = *reinterpret_cast<const Half4x4*>(&p_a_block[(mMyThreadOffsetA + k * M + MPerLevel1Cluster) * 4]);

                 outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
                 outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
             }

@@ -447,6 +453,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                         FloatC* __restrict__ p_c_thread) const
     {
+        // The assembly path doesn't support bfloat16 using asm instructions
 #if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
+        static_if<std::is_same<FloatA, ushort>::value && std::is_same<FloatB, ushort>::value>{}(
+            [&](auto) { Run_source(p_a_block, p_b_block, p_c_thread); })

@@ -454,7 +462,25 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             Run_amd_asm(p_a_block, p_b_block, p_c_thread);
         });
 #else
+        static_if<std::is_same<FloatA, half>::value && std::is_same<FloatB, half>::value>{}([&](auto) {
+            // Vectorize the pointer to match with how half/bfloat16 datatypes are
+            // processed in gemm operation. Half type packs 4 half values while
+            // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
+            // 2D indexes are computed with a single value in mind (e.g. float),
+            // to retain the same 2D indexes for half/bfloat16, we recast datatype
+            // from a single half to 4 packed half/2 packed bfloat16 respectively.
+            const vector_type<half, 4>::MemoryType* p_a_block_vec =
+                reinterpret_cast<const vector_type<half, 4>::MemoryType*>(p_a_block);
+            const vector_type<half, 4>::MemoryType* p_b_block_vec =
+                reinterpret_cast<const vector_type<half, 4>::MemoryType*>(p_b_block);
+            Run_source(p_a_block_vec, p_b_block_vec, p_c_thread);
+        }).Else([&](auto) {
+            // If A and B datatype is bfloat16/float16
             Run_source(p_a_block, p_b_block, p_c_thread);
+        });
 #endif
     }
 };
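The "* 4" factors in the new loads fall out of the pointer recast done in Run(): once p_a_block and p_b_block are reinterpreted as packed 4-half vectors, an offset originally computed in scalar elements has to be rescaled by the pack width. A minimal host-side sketch of that index arithmetic (stand-in types, not the commit's code; uint16_t replaces the device half so the check runs on the host):

    #include <cassert>
    #include <cstdint>

    using half_t = std::uint16_t;      // stand-in for the device 16-bit half
    struct half4 { half_t data[4]; };  // packed element, like vector_type<half, 4>::MemoryType

    int main()
    {
        half_t lds[64] = {};
        const half4* vec = reinterpret_cast<const half4*>(lds);

        // One packed element covers 4 scalars, so element offset i in the
        // packed view corresponds to scalar offset i * 4 in the raw view --
        // the same rescaling applied to mMyThreadOffsetA/B above.
        const int i = 5;
        assert(reinterpret_cast<const half_t*>(&vec[i]) == &lds[i * 4]);
        return 0;
    }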
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -56,11 +56,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
                 }
             }
         }).Else([&](auto) {
-            static_if<std::is_same<Float, half>::value>{}([&](auto) {
-                // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4 respectively)
-                using vector_t = typename vector_type<Float, 4>::MemoryType;
+            // fp16/bfp16
             for(index_t i = 0; i < NRow; ++i)
             {
                 for(index_t j = 0; j < NCol; ++j)

@@ -68,26 +64,10 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
                     const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
                     const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);

-                    *reinterpret_cast<vector_t*>(&p_dst[dst_index * 4]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_index * 4]);
+                    *reinterpret_cast<Float*>(&p_dst[dst_index]) =
+                        *reinterpret_cast<const Float*>(&p_src[src_index]);
                 }
             }
-            }).Else([&](auto) {
-                using vector_t = typename vector_type<Float, 2>::MemoryType;
-
-                for(index_t i = 0; i < NRow; ++i)
-                {
-                    for(index_t j = 0; j < NCol; ++j)
-                    {
-                        const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
-                        const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
-
-                        *reinterpret_cast<vector_t*>(&p_dst[dst_index * 2]) =
-                            *reinterpret_cast<const vector_t*>(&p_src[src_index * 2]);
-                    }
-                }
-            });
         });
     }

@@ -129,32 +109,35 @@ __device__ void threadwise_gemm(MatrixA,
                 const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
                 const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);

-                static_if<std::is_same<FloatA, float>::value>{}([&](auto) {
-                    p_c_thread[cindex] +=
-                        CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
-                }).Else([&](auto) {
-                    static_if<std::is_same<FloatA, half>::value>{}([&](auto) {
-                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
-                        // respectively)
-                        float acc = 0.0;
-                        for(index_t v = 0; v < 4; ++v)
-                        {
-                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 4 + v]) *
-                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 4 + v]);
-                        }
-                        p_c_thread[cindex] += acc;
-                    }).Else([&](auto) {
-                        // If src/dst matrix datatype is bfloat16/float16 (vector size 2/4
-                        // respectively)
-                        float acc = 0.0;
-                        for(index_t v = 0; v < 2; ++v)
-                        {
-                            acc += CVT_FLOAT2ACCUM(p_a_thread[aindex * 2 + v]) *
-                                   CVT_FLOAT2ACCUM(p_b_thread[bindex * 2 + v]);
-                        }
-                        p_c_thread[cindex] += acc;
-                    });
-                });
+#if MIOPEN_USE_FP32 == 1
+                p_c_thread[cindex] +=
+                    CVT_FLOAT2ACCUM(p_a_thread[aindex]) * CVT_FLOAT2ACCUM(p_b_thread[bindex]);
+#elif MIOPEN_USE_FP16 == 1
+                const half* p_a_thread_half = reinterpret_cast<const half*>(&p_a_thread[aindex]);
+                const half* p_b_thread_half = reinterpret_cast<const half*>(&p_b_thread[bindex]);
+
+                float acc = 0.0;
+                for(index_t v = 0; v < 4; ++v)
+                {
+                    acc += CVT_FLOAT2ACCUM(p_a_thread_half[v]) * CVT_FLOAT2ACCUM(p_b_thread_half[v]);
+                }
+                p_c_thread[cindex] += acc;
+#elif MIOPEN_USE_BF16 == 1
+                const ushort* p_a_thread_ushort = reinterpret_cast<const ushort*>(&p_a_thread[aindex]);
+                const ushort* p_b_thread_ushort = reinterpret_cast<const ushort*>(&p_b_thread[bindex]);
+
+                float acc = 0.0;
+                for(index_t v = 0; v < 2; ++v)
+                {
+                    acc += CVT_FLOAT2ACCUM(p_a_thread_ushort[v]) * CVT_FLOAT2ACCUM(p_b_thread_ushort[v]);
+                }
+                p_c_thread[cindex] += acc;
+#else
+#endif
             }
         }
     }
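For reference, the MIOPEN_USE_FP16 branch above computes one C element per (i, j) as a 4-wide dot product over the packed halves, accumulated in float. A runnable scalar model of that step (plain floats stand in for half, and CVT_FLOAT2ACCUM becomes the identity; values are illustrative):

    #include <cstdio>

    int main()
    {
        // one packed A element and one packed B element (4 halves each, as floats)
        float a_packed[4] = {0.5f, 1.0f, 2.0f, -1.0f};
        float b_packed[4] = {2.0f, 1.0f, 0.5f, 1.0f};

        float c   = 10.0f; // accumulator carried across the k loop
        float acc = 0.0f;
        for(int v = 0; v < 4; ++v)
            acc += a_packed[v] * b_packed[v];
        c += acc;

        std::printf("c = %f\n", c); // 10 + (1 + 1 + 1 - 1) = 12
        return 0;
    }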
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp

@@ -112,7 +112,6 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
             static_if<std::is_same<vector_src_t, vector_dest_t>::value>{}([&](auto) {
                 *reinterpret_cast<vector_dest_t*>(&p_dst[dst_index]) =
                     *reinterpret_cast<const vector_src_t*>(&p_src[src_index]);
-                // printf("%f ", static_cast<float>(p_dst[dst_index]));
             }).Else([&](auto) {
                 for(unsigned int data_idx = 0; data_idx < DataPerAccess; ++data_idx)
                 {
composable_kernel/include/utility/amd_inline_asm.hpp

@@ -147,8 +147,9 @@ __device__ void outerProduct1x4(const half2* a, const half2* b, float* c)
                 "3"(c[3])); // 3rd Src Acc registers for 2 half2 registers
 }

-__device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
-                                    const vector_type<vector_type<half, 4>, 4>& b,
+__device__ void outerProduct1x4Half(const vector_type<half, 4>::MemoryType& a,
+                                    const vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType& b,
                                     vector_type<float, 4>::MemoryType& c)
 {
     outerProduct1x4(reinterpret_cast<const half2*>(&a),

@@ -156,14 +157,16 @@ __device__ void outerProduct1x4Half(const vector_type<half, 4>& a,
                     reinterpret_cast<float*>(&c));
 }

-__device__ void outerProduct4x4(const vector_type<vector_type<half, 4>, 4>& a,
-                                const vector_type<vector_type<half, 4>, 4>& b,
+__device__ void outerProduct4x4(const vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType& a,
+                                const vector_type<vector_type<half, 4>::MemoryType, 4>::MemoryType& b,
                                 vector_type<float, 4>::MemoryType& c0,
                                 vector_type<float, 4>::MemoryType& c1,
                                 vector_type<float, 4>::MemoryType& c2,
                                 vector_type<float, 4>::MemoryType& c3)
 {
-    const vector_type<half, 4>* reg_a = reinterpret_cast<const vector_type<half, 4>*>(&a);
+    const vector_type<half, 4>::MemoryType* reg_a =
+        reinterpret_cast<const vector_type<half, 4>::MemoryType*>(&a);

     outerProduct1x4Half(reg_a[0], b, c0);
     outerProduct1x4Half(reg_a[1], b, c1);
     outerProduct1x4Half(reg_a[2], b, c2);
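My reading of the contract behind outerProduct4x4 (a sketch, not the commit's code): a carries four packed 4-half rows of an A tile, b carries four packed 4-vectors of the B tile, and each outerProduct1x4Half call folds the packed dimension into one float4 accumulator. Whether b's packed groups are rows or columns of B depends on the surrounding data layout; the model below assumes one group per output column:

    #include <cstdio>

    // c[r][j] += dot4(row r of a, packed group j of b) -- assumed layout.
    static void outer_product_4x4_ref(const float a[4][4], const float b[4][4], float c[4][4])
    {
        for(int r = 0; r < 4; ++r)
            for(int j = 0; j < 4; ++j)
                for(int k = 0; k < 4; ++k)
                    c[r][j] += a[r][k] * b[j][k];
    }

    int main()
    {
        float a[4][4] = {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}};
        float b[4][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}};
        float c[4][4] = {};
        outer_product_4x4_ref(a, b, c);
        std::printf("%f %f\n", c[0][0], c[3][3]); // 1 and 16 with identity A
        return 0;
    }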
composable_kernel/include/utility/config_amd.hpp.in

@@ -15,6 +15,19 @@ namespace ck {
 // instruction
 typedef float float2_t __attribute__((ext_vector_type(2)));
 typedef float float4_t __attribute__((ext_vector_type(4)));
+typedef half2 half2_t;
+typedef struct
+{
+    half2 vector[2];
+} half4_t;
+typedef struct
+{
+    ushort vector[2];
+} ushort2_t;
+typedef struct
+{
+    ushort2_t vector[2];
+} ushort4_t;

 using index_t = uint32_t;
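Note that half4_t is declared as a struct wrapping two half2 values rather than a 4-wide ext_vector_type. The reinterpret_casts in blockwise_gemm.hpp and amd_inline_asm.hpp rely on this struct being layout-compatible with four contiguous halves (8 bytes, no padding). A host-side sanity check of that assumption, using stand-in types:

    #include <cstdint>

    // Stand-ins mirroring the declarations above (uint16_t in place of half).
    struct half2_s { std::uint16_t x, y; };
    typedef struct { half2_s vector[2]; } half4_s;

    // The pointer casts assume half4_t packs exactly four 16-bit values.
    static_assert(sizeof(half4_s) == 4 * sizeof(std::uint16_t), "half4_t must pack tightly");

    int main() { return 0; }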
composable_kernel/include/utility/config_nvidia.hpp.in

@@ -19,6 +19,19 @@ namespace ck {
 // instruction,
 using float2_t = float2;
 using float4_t = float4;
+typedef half2 half2_t;
+typedef struct
+{
+    half2 vector[2];
+} half4_t;
+typedef struct
+{
+    ushort vector[2];
+} ushort2_t;
+typedef struct
+{
+    ushort2_t vector[2];
+} ushort4_t;

 using index_t = uint32_t;
composable_kernel/include/utility/vector_type.hpp

@@ -10,7 +10,11 @@ namespace ck {
 template <class T, index_t N>
 struct vector_type
 {
+    typedef struct
+    {
+        T vector[N];
+    } MemoryType;
+
+    MemoryType mData;
 };

 template <>

@@ -33,9 +37,7 @@ struct vector_type<float, 2>
 {
     using MemoryType = float2_t;

     __host__ __device__ static constexpr index_t GetSize() { return 2; }

-    union Data
+    union DataType
     {
         MemoryType vector;
         float scalar[2];

@@ -48,6 +50,13 @@ struct vector_type<float, 2>
         *(reinterpret_cast<float*>(&v) + I) = s;
     }
+
+    __host__ __device__ static MemoryType Pack(float s0, float s1)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
 };

 template <>

@@ -83,9 +92,9 @@ struct vector_type<half, 1>
 template <>
 struct vector_type<half, 2>
 {
-    using MemoryType = half2;
+    using MemoryType = half2_t;

-    union Data
+    union DataType
     {
         MemoryType vector;
         half scalar[2];

@@ -100,17 +109,25 @@ struct vector_type<half, 2>
         *(reinterpret_cast<half*>(&v) + I) = s;
     }
+
+    __host__ __device__ static MemoryType Pack(half s0, half s1)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
 };

 template <>
 struct vector_type<half, 4>
 {
-    typedef struct MemoryType
-    {
-        half2 vector[2];
-    } MemoryType;
+    using MemoryType = half4_t;

     __host__ __device__ static constexpr index_t GetSize() { return 4; }

+    union DataType
+    {
+        MemoryType vector;
+        half scalar[4];
+    };
+
     template <index_t I>
     __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)

@@ -118,6 +135,16 @@ struct vector_type<half, 4>
         static_assert(I < 4, "wrong");
         *(reinterpret_cast<half*>(&v) + I) = s;
     }
+
+    __host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        data.scalar[2] = s2;
+        data.scalar[3] = s3;
+        return data.vector;
+    }
 };

 template <>

@@ -138,12 +165,12 @@ struct vector_type<ushort, 1>
 template <>
 struct vector_type<ushort, 2>
 {
-    using MemoryType = ushort2;
+    using MemoryType = ushort2_t;

-    union Data
+    union DataType
     {
         MemoryType vector;
-        half scalar[2];
+        ushort scalar[2];
     };

     __host__ __device__ static constexpr index_t GetSize() { return 2; }

@@ -155,17 +182,25 @@ struct vector_type<ushort, 2>
         *(reinterpret_cast<ushort*>(&v) + I) = s;
     }
+
+    __host__ __device__ static MemoryType Pack(ushort s0, ushort s1)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        return data.vector;
+    }
 };

 template <>
 struct vector_type<ushort, 4>
 {
-    typedef struct MemoryType
-    {
-        ushort2 vector[2];
-    } MemoryType;
+    using MemoryType = ushort4_t;

     __host__ __device__ static constexpr index_t GetSize() { return 4; }

+    union DataType
+    {
+        MemoryType vector;
+        ushort scalar[4];
+    };
+
     template <index_t I>
     __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)

@@ -173,6 +208,16 @@ struct vector_type<ushort, 4>
         static_assert(I < 4, "wrong");
         *(reinterpret_cast<ushort*>(&v) + I) = s;
     }
+
+    __host__ __device__ static MemoryType Pack(ushort s0, ushort s1, ushort s2, ushort s3)
+    {
+        DataType data;
+        data.scalar[0] = s0;
+        data.scalar[1] = s1;
+        data.scalar[2] = s2;
+        data.scalar[3] = s3;
+        return data.vector;
+    }
 };

 } // namespace ck
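The new Pack helpers write through the scalar view of the DataType union and read back the vector view. A host-side sketch of the same pattern (a plain struct stands in for float2_t so this compiles without HIP/CUDA headers):

    #include <cstdio>

    struct float2_s { float x, y; };  // stand-in for float2_t

    union DataType
    {
        float2_s vector;
        float    scalar[2];
    };

    static float2_s Pack(float s0, float s1)
    {
        DataType data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        // Reading the member last written through scalar[]; this union
        // type-punning mirrors the original (well-defined in C, accepted
        // by the GPU toolchains this header targets).
        return data.vector;
    }

    int main()
    {
        float2_s v = Pack(1.5f, -2.0f);
        std::printf("%f %f\n", v.x, v.y); // 1.500000 -2.000000
        return 0;
    }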
driver/src/driver.cpp

@@ -801,6 +801,22 @@ int main(int argc, char* argv[])
     using ConvStrides   = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

     constexpr index_t HPad = 0;
     constexpr index_t WPad = 0;
+#elif 0
+    // 1x1 filter, 4x4 image
+    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
+    constexpr index_t N  = 8;
+    constexpr index_t C  = 64;
+    constexpr index_t HI = 4;
+    constexpr index_t WI = 4;
+    constexpr index_t K  = 64;
+    constexpr index_t Y  = 1;
+    constexpr index_t X  = 1;
+
+    using ConvStrides   = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    constexpr index_t HPad = 0;
+    constexpr index_t WPad = 0;
 #endif

@@ -897,7 +913,7 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-#if 1
+#if 0
         if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
            ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
         {