Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5e98fc5b
"example/vscode:/vscode.git/clone" did not exist on "12235112a10ecbe47acead9a03564cb42c4624c2"
Commit
5e98fc5b
authored
Sep 07, 2023
by
Jing Zhang
Browse files
fixed fp8 init; and reference gemm
parent
37a8c1f7
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
35 additions
and
18 deletions
+35
-18
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+11
-5
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+4
-5
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+5
-5
library/include/ck/library/utility/host_tensor_generator.hpp
library/include/ck/library/utility/host_tensor_generator.hpp
+12
-0
profiler/include/profiler/profile_gemm_multiply_add_impl.hpp
profiler/include/profiler/profile_gemm_multiply_add_impl.hpp
+2
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
5e98fc5b
...
...
@@ -247,7 +247,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
5e98fc5b
...
...
@@ -27,6 +27,12 @@ struct PassThrough
y
=
x
;
}
// Pass-through specialization for a double input and a float output:
// writes x into y, narrowing the value from double to float.
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
    y = static_cast<float>(x);
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
>
(
float
&
y
,
const
float
&
x
)
const
{
...
...
@@ -412,14 +418,14 @@ struct Swish
{
Swish
(
float
beta
=
1.0
f
)
:
beta_
(
beta
)
{}
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
ck
::
half_t
>::
value
,
static_assert
(
is_same
<
X
,
float
>::
value
||
is_same
<
X
,
double
>::
value
||
is_same
<
X
,
ck
::
half_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
/
(
ck
::
type_convert
<
T
>
(
1
)
+
ck
::
math
::
exp
(
-
beta_
*
x
));
y
=
x
/
(
ck
::
type_convert
<
Y
>
(
1
)
+
ck
::
math
::
exp
(
-
beta_
*
x
));
};
float
beta_
=
1.0
f
;
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
5e98fc5b
...
...
@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
v
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
v
;
});
const
bool
is_dst_valid
=
...
...
@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_slice_origin_idx
+
idx_md
+
i
*
dst_scalar_step_in_vector
);
Src
Data
v
;
Dst
Data
v
;
// apply element-wise operation
element_op_
(
v
,
src_buf
[
Number
<
src_offset
>
{}]);
// apply type convert
dst_buf
(
Number
<
dst_offset
>
{})
=
type_convert
<
DstData
>
(
v
)
;
dst_buf
(
Number
<
dst_offset
>
{})
=
v
;
});
});
}
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
5e98fc5b
...
...
@@ -20,7 +20,8 @@ template <typename ADataType,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
,
typename
ComputeType
=
ADataType
>
struct
ReferenceGemm
:
public
device
::
BaseOperator
{
// Argument
...
...
@@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
AData
Type
v_a
;
BData
Type
v_b
;
Compute
Type
v_a
;
Compute
Type
v_b
;
// use PassThrough instead of ConvertBF16RTN for reference calculation
if
constexpr
(
is_same_v
<
AElementwiseOperation
,
...
...
@@ -88,8 +89,7 @@ struct ReferenceGemm : public device::BaseOperator
arg
.
b_element_op_
(
v_b
,
arg
.
b_k_n_
(
k
,
n
));
}
v_acc
+=
ck
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck
::
type_convert
<
AccDataType
>
(
v_b
);
v_acc
+=
type_convert
<
AccDataType
>
(
v_a
*
v_b
);
}
CDataType
v_c
;
...
...
library/include/ck/library/utility/host_tensor_generator.hpp
View file @
5e98fc5b
...
...
@@ -55,6 +55,18 @@ struct GeneratorTensor_1<int8_t>
}
};
// Specialization of GeneratorTensor_1 for ck::f8_t.
// Produces the same constant for every element, regardless of the index
// arguments it is called with.
template <>
struct GeneratorTensor_1<ck::f8_t>
{
    // Fill value, stored as float and converted to ck::f8_t on each call.
    float value = 1.0;

    // Accepts any number of (ignored) arguments and returns `value`
    // converted to ck::f8_t via ck::type_convert.
    template <typename... Indices>
    ck::f8_t operator()(Indices...)
    {
        const ck::f8_t converted = ck::type_convert<ck::f8_t>(value);
        return converted;
    }
};
template
<
typename
T
>
struct
GeneratorTensor_2
{
...
...
profiler/include/profiler/profile_gemm_multiply_add_impl.hpp
View file @
5e98fc5b
...
...
@@ -83,8 +83,8 @@ bool profile_gemm_multiply_add_impl(int do_verification,
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
1
,
1
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.
5
,
0.
5
});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
0.2
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.
1
,
0.
1
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D0DataType
>
{
0.0
,
1.0
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
D1DataType
>
{
0.0
,
1.0
});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment