gaoqiong / composable_kernel_ROCM · Commits

Unverified commit 32371ea5, authored Mar 07, 2024 by Illia Silin, committed via GitHub on Mar 07, 2024

Merge branch 'develop' into navi3_rel

Parents: e42f9ecf, 0e28de97
Changes: 40 files in the merge; this page shows 20 changed files, with 1785 additions and 123 deletions (+1785 −123).
codegen/test/include/test.hpp                       +848  −0
codegen/test/rtc/CMakeLists.txt                     +6    −0
codegen/test/rtc/include/rtc/compile_kernel.hpp     +27   −0
codegen/test/rtc/include/rtc/hip.hpp                +78   −0
codegen/test/rtc/include/rtc/kernel.hpp             +62   −0
codegen/test/rtc/include/rtc/manage_ptr.hpp         +55   −0
codegen/test/rtc/include/rtc/tmp_dir.hpp            +24   −0
codegen/test/rtc/src/compile_kernel.cpp             +95   −0
codegen/test/rtc/src/hip.cpp                        +102  −0
codegen/test/rtc/src/kernel.cpp                     +121  −0
codegen/test/rtc/src/tmp_dir.cpp                    +48   −0
docs/dockerhub.rst                                  +1    −1
docs/sphinx/requirements.in                         +1    −1
docs/sphinx/requirements.txt                        +1    −1
example/01_gemm/gemm_xdl_fp8.cpp                    +9    −5
example/01_gemm/gemm_xdl_fp8_bf8.cpp                +4    −4
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  +237  −64
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                          +29   −24
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp       +5    −5
include/ck/utility/type_convert.hpp                                                  +32   −18
codegen/test/include/test.hpp (new file, mode 100644)

This diff is collapsed on the page and is not reproduced here (+848 lines).
codegen/test/rtc/CMakeLists.txt (new file, mode 100644)

find_package(hip)

file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
add_library(ck_rtc ${RTC_SOURCES})
target_include_directories(ck_rtc PUBLIC include)
target_link_libraries(ck_rtc PUBLIC hip::host)
codegen/test/rtc/include/rtc/compile_kernel.hpp (new file, mode 100644)

#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL

#include <rtc/kernel.hpp>
#include <filesystem>
#include <string>
#include <string_view> // for std::string_view in src_file
#include <vector>      // for std::vector in the compile_kernel signature

namespace rtc {

struct src_file
{
    std::filesystem::path path;
    std::string_view content;
};

struct compile_options
{
    std::string flags       = "";
    std::string kernel_name = "main";
};

kernel compile_kernel(const std::vector<src_file>& src,
                      compile_options options = compile_options{});

} // namespace rtc
#endif
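
A minimal usage sketch for this header (the kernel source, the file name, and the "add_one" symbol below are illustrative, not part of the commit):

#include <rtc/compile_kernel.hpp>
#include <string_view>

// Hypothetical HIP source; clang's HIP driver pre-includes its runtime
// wrapper, so __global__ and threadIdx are available without extra includes.
constexpr std::string_view add_one_src = R"(
extern "C" __global__ void add_one(float* x) { x[threadIdx.x] += 1.0f; }
)";

int main()
{
    rtc::compile_options options;
    options.kernel_name = "add_one"; // override the "main" default
    rtc::kernel k = rtc::compile_kernel({{"add_one.cpp", add_one_src}}, options);
}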
codegen/test/rtc/include/rtc/hip.hpp (new file, mode 100644)

#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP

#include <hip/hip_runtime_api.h>
#include <memory>
#include <stdexcept> // for std::runtime_error thrown by at()
#include <string>

namespace rtc {

template <class T>
struct buffer
{
    buffer() : ptr(), n(0) {}
    buffer(std::shared_ptr<T> p, std::size_t sz) : ptr(p), n(sz) {}
    buffer(std::shared_ptr<void> p, std::size_t sz)
        : ptr(std::reinterpret_pointer_cast<T>(p)), n(sz)
    {
    }
    // note: the array deleter is required; a plain shared_ptr<T>(new T[sz])
    // would release the allocation with delete instead of delete[]
    explicit buffer(std::size_t sz) : ptr(new T[sz], std::default_delete<T[]>()), n(sz) {}

    T* begin() { return data(); }
    T* end() { return data() + size(); }
    const T* begin() const { return data(); }
    const T* end() const { return data() + size(); }

    T& front() { return data()[0]; }
    T& back() { return data()[size() - 1]; }
    T& operator[](std::size_t i) { return data()[i]; }
    T& at(std::size_t i)
    {
        if(i >= size())
            throw std::runtime_error("Out of bounds");
        return data()[i];
    }

    const T& front() const { return data()[0]; }
    const T& back() const { return data()[size() - 1]; }
    const T& operator[](std::size_t i) const { return data()[i]; }
    const T& at(std::size_t i) const
    {
        if(i >= size())
            throw std::runtime_error("Out of bounds");
        return data()[i];
    }

    const T* data() const { return ptr.get(); }
    T* data() { return ptr.get(); }
    std::size_t size() const { return n; }
    std::size_t bytes() const { return size() * sizeof(T); }
    bool empty() const { return size() == 0; }

    private:
    std::shared_ptr<T> ptr;
    std::size_t n;
};

std::string get_device_name();
std::string hip_error(int error);

std::shared_ptr<void> allocate_gpu(std::size_t sz, bool host = false);
std::shared_ptr<void> write_to_gpu(const void* x, std::size_t sz, bool host = false);
std::shared_ptr<void> read_from_gpu(const void* x, std::size_t sz);

template <class T>
buffer<T> to_gpu(const buffer<T>& input)
{
    return {write_to_gpu(input.data(), input.bytes()), input.size()};
}

template <class T>
buffer<T> from_gpu(const buffer<T>& input)
{
    return {read_from_gpu(input.data(), input.bytes()), input.size()};
}

} // namespace rtc
#endif
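
A sketch of the intended round trip through buffer<T>, to_gpu, and from_gpu (this assumes a working HIP device and is not code from the commit):

#include <rtc/hip.hpp>
#include <cassert>
#include <numeric>

int main()
{
    rtc::buffer<float> host(256);              // host-side array allocation
    std::iota(host.begin(), host.end(), 0.0f); // fill with 0, 1, 2, ...

    auto device = rtc::to_gpu(host);     // allocate_gpu + hipMemcpy host-to-device
    auto result = rtc::from_gpu(device); // hipMemcpy device-to-host into fresh storage

    assert(result.size() == host.size());
    assert(result.at(255) == host.at(255)); // values survive the round trip
}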
codegen/test/rtc/include/rtc/kernel.hpp (new file, mode 100644)

#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL

#include <hip/hip_runtime_api.h>
#include <memory>
#include <string>
#include <type_traits> // for remove_reference_t/enable_if_t/is_base_of below
#include <vector>

namespace rtc {

struct kernel_argument
{
    template <class T,
              class U = std::remove_reference_t<T>,
              class   = std::enable_if_t<not std::is_base_of<kernel_argument, T>{}>>
    kernel_argument(T&& x) : size(sizeof(U)), align(alignof(U)), data(&x) // NOLINT
    {
    }
    std::size_t size;
    std::size_t align;
    void* data;
};

std::vector<char> pack_args(const std::vector<kernel_argument>& args);

struct kernel_impl;

struct kernel
{
    kernel() = default;
    kernel(const char* image, const std::string& name);
    template <class T>
    kernel(const std::vector<T>& image, const std::string& name)
        : kernel(reinterpret_cast<const char*>(image.data()), name)
    {
        static_assert(sizeof(T) == 1, "Only byte types");
    }

    void launch(hipStream_t stream,
                std::size_t global,
                std::size_t local,
                const std::vector<kernel_argument>& args) const;

    void launch(hipStream_t stream,
                std::size_t global,
                std::size_t local,
                std::vector<void*> args) const;

    template <class... Ts>
    auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const
    {
        return [=](auto&&... xs) {
            launch(stream, global, local, std::vector<kernel_argument>{xs...}, zs...);
        };
    }

    private:
    std::shared_ptr<kernel_impl> impl;
};

} // namespace rtc
#endif
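
One subtlety worth noting: kernel_argument stores the address of whatever it is constructed from, so each argument must stay alive until launch() returns. A sketch of the curried variadic overload (the stream, sizes, and device pointer are placeholders):

#include <rtc/kernel.hpp>

void run(const rtc::kernel& k, void* d_x)
{
    std::size_t global = 256; // total work-items
    std::size_t local  = 64;  // work-group size
    // the variadic overload returns a callable; invoking it packs the
    // arguments into a std::vector<kernel_argument> and launches
    k.launch(nullptr, global, local)(d_x);
}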
codegen/test/rtc/include/rtc/manage_ptr.hpp (new file, mode 100644)

#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER

#include <type_traits>
#include <memory>
#include <utility> // for std::move in share()

namespace rtc {

template <class F, F f>
struct manage_deleter
{
    template <class T>
    void operator()(T* x) const
    {
        if(x != nullptr)
        {
            (void)f(x);
        }
    }
};

struct null_deleter
{
    template <class T>
    void operator()(T*) const
    {
    }
};

template <class T, class F, F f>
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;

template <class T>
struct element_type
{
    using type = typename T::element_type;
};

template <class T>
using remove_ptr = typename std::
    conditional_t<std::is_pointer<T>{}, std::remove_pointer<T>, element_type<T>>::type;

template <class T>
using shared = std::shared_ptr<remove_ptr<T>>;

template <class T>
shared<T> share(T p)
{
    return shared<T>{std::move(p)};
}

#define RTC_MANAGE_PTR(T, F) rtc::manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F>

} // namespace rtc
#endif
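
The macro is easiest to read through an example outside HIP; a hypothetical wrapper around a C FILE* (not part of the commit) shows the shape:

#include <rtc/manage_ptr.hpp>
#include <cstdio>

// expands to unique_ptr<FILE, manage_deleter<decltype(&fclose), &fclose>>;
// the deleter calls fclose(x) only when the pointer is non-null
using file_ptr = RTC_MANAGE_PTR(FILE*, fclose);

void demo()
{
    file_ptr f{std::fopen("/tmp/demo.txt", "w")};
    if(f)
        std::fputs("hello\n", f.get()); // fclose runs automatically at scope exit
}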
codegen/test/rtc/include/rtc/tmp_dir.hpp (new file, mode 100644)

#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR
#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR

#include <string>
#include <filesystem>

namespace rtc {

struct tmp_dir
{
    std::filesystem::path path;
    tmp_dir(const std::string& prefix = "");

    void execute(const std::string& cmd) const;

    tmp_dir(tmp_dir const&) = delete;
    tmp_dir& operator=(tmp_dir const&) = delete;

    ~tmp_dir();
};

} // namespace rtc
#endif
codegen/test/rtc/src/compile_kernel.cpp (new file, mode 100644)

#include "rtc/hip.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <iostream>
#include <fstream>
#include <cassert>

namespace rtc {

template <class T>
T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
    std::ifstream is(filename, std::ios::binary | std::ios::ate);
    if(nbytes == 0)
    {
        // if there is a non-zero offset and nbytes is not set,
        // calculate size of remaining bytes to read
        nbytes = is.tellg();
        if(offset > nbytes)
            throw std::runtime_error("offset is larger than file size");
        nbytes -= offset;
    }
    if(nbytes < 1)
        throw std::runtime_error("Invalid size for: " + filename);
    is.seekg(offset, std::ios::beg);

    T buffer(nbytes, 0);
    if(not is.read(&buffer[0], nbytes))
        throw std::runtime_error("Error reading file: " + filename);
    return buffer;
}

std::vector<char> read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0)
{
    return generic_read_file<std::vector<char>>(filename, offset, nbytes);
}

std::string read_string(const std::string& filename)
{
    return generic_read_file<std::string>(filename);
}

void write_buffer(const std::string& filename, const char* buffer, std::size_t size)
{
    std::ofstream os(filename);
    os.write(buffer, size);
}

void write_buffer(const std::string& filename, const std::vector<char>& buffer)
{
    write_buffer(filename, buffer.data(), buffer.size());
}

void write_string(const std::string& filename, const std::string_view& buffer)
{
    write_buffer(filename, buffer.data(), buffer.size());
}

std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; }

kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
{
    assert(not srcs.empty());
    tmp_dir td{"compile"};
    options.flags += " -I. -O3";
    options.flags += " -std=c++17";
    options.flags += " --offload-arch=" + get_device_name();
    std::string out;
    for(const auto& src : srcs)
    {
        std::filesystem::path full_path   = td.path / src.path;
        std::filesystem::path parent_path = full_path.parent_path();
        std::filesystem::create_directories(parent_path);
        write_string(full_path.string(), src.content);
        if(src.path.extension().string() == ".cpp")
        {
            options.flags += " -c " + src.path.filename().string();
            if(out.empty())
                out = src.path.stem().string() + ".o";
        }
    }
    options.flags += " -o " + out;
    td.execute(compiler() + options.flags);

    auto out_path = td.path / out;
    if(not std::filesystem::exists(out_path))
        throw std::runtime_error("Output file missing: " + out);

    auto obj = read_buffer(out_path.string());

    return kernel{obj.data(), options.kernel_name};
}

} // namespace rtc
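
Taken together, compile_kernel stages every source into the scratch directory, compiles the .cpp inputs in place, and loads the first object file back as the code object. For a single source kernel.cpp on a gfx90a device (the file name and architecture are assumptions for illustration), the assembled command would look roughly like:

    /opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only -I. -O3 -std=c++17 --offload-arch=gfx90a -c kernel.cpp -o kernel.o

run with the temporary directory as the working directory; kernel.o is then read back with read_buffer and handed to the kernel constructor.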
codegen/test/rtc/src/hip.cpp (new file, mode 100644)

#include <rtc/hip.hpp>
#include <rtc/manage_ptr.hpp>
#include <stdexcept>
#include <cassert>

namespace rtc {

using hip_ptr = RTC_MANAGE_PTR(void, hipFree);

std::string hip_error(int error)
{
    return hipGetErrorString(static_cast<hipError_t>(error));
}

int get_device_id()
{
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
        throw std::runtime_error("No device");
    return device;
}

std::string get_device_name()
{
    hipDeviceProp_t props{};
    auto status = hipGetDeviceProperties(&props, get_device_id());
    if(status != hipSuccess)
        throw std::runtime_error("Failed to get device properties");
    return props.gcnArchName;
}

bool is_device_ptr(const void* ptr)
{
    hipPointerAttribute_t attr;
    auto status = hipPointerGetAttributes(&attr, ptr);
    if(status != hipSuccess)
        return false;
    return attr.type == hipMemoryTypeDevice;
}

void gpu_sync()
{
    auto status = hipDeviceSynchronize();
    if(status != hipSuccess)
        throw std::runtime_error("hip device synchronization failed: " + hip_error(status));
}

std::size_t get_available_gpu_memory()
{
    size_t free;
    size_t total;
    auto status = hipMemGetInfo(&free, &total);
    if(status != hipSuccess)
        throw std::runtime_error("Failed getting available memory: " + hip_error(status));
    return free;
}

std::shared_ptr<void> allocate_gpu(std::size_t sz, bool host)
{
    if(sz > get_available_gpu_memory())
        throw std::runtime_error("Memory not available to allocate buffer: " + std::to_string(sz));
    void* alloc_ptr = nullptr;
    auto status     = host ? hipHostMalloc(&alloc_ptr, sz) : hipMalloc(&alloc_ptr, sz);
    if(status != hipSuccess)
    {
        if(host)
            throw std::runtime_error("Gpu allocation failed: " + hip_error(status));
        else
            return allocate_gpu(sz, true); // fall back to pinned host memory
    }
    assert(alloc_ptr != nullptr);
    std::shared_ptr<void> result = share(hip_ptr{alloc_ptr});
    return result;
}

std::shared_ptr<void> write_to_gpu(const void* x, std::size_t sz, bool host)
{
    gpu_sync();
    auto result = allocate_gpu(sz, host);
    assert(is_device_ptr(result.get()));
    assert(not is_device_ptr(x));
    auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice);
    if(status != hipSuccess)
        throw std::runtime_error("Copy to gpu failed: " + hip_error(status));
    return result;
}

std::shared_ptr<void> read_from_gpu(const void* x, std::size_t sz)
{
    gpu_sync();
    // array deleter added so the allocation is released with delete[]
    std::shared_ptr<char> result(new char[sz], std::default_delete<char[]>()); // NOLINT
    assert(not is_device_ptr(result.get()));
    if(not is_device_ptr(x))
    {
        throw std::runtime_error(
            "read_from_gpu() requires Src buffer to be on the GPU, Copy from gpu failed\n");
    }
    auto status = hipMemcpy(result.get(), x, sz, hipMemcpyDeviceToHost);
    if(status != hipSuccess)
        throw std::runtime_error("Copy from gpu failed: " + hip_error(status));
    return std::static_pointer_cast<void>(result);
}

} // namespace rtc
0 → 100644
View file @
32371ea5
#include <rtc/kernel.hpp>
#include <rtc/manage_ptr.hpp>
#include <rtc/hip.hpp>
#include <cassert>
// extern declare the function since hip/hip_ext.h header is broken
extern
hipError_t
hipExtModuleLaunchKernel
(
hipFunction_t
,
// NOLINT
uint32_t
,
uint32_t
,
uint32_t
,
uint32_t
,
uint32_t
,
uint32_t
,
size_t
,
hipStream_t
,
void
**
,
void
**
,
hipEvent_t
=
nullptr
,
hipEvent_t
=
nullptr
,
uint32_t
=
0
);
namespace
rtc
{
std
::
vector
<
char
>
pack_args
(
const
std
::
vector
<
kernel_argument
>&
args
)
{
std
::
vector
<
char
>
kernargs
;
for
(
auto
&&
arg
:
args
)
{
std
::
size_t
n
=
arg
.
size
;
const
auto
*
p
=
static_cast
<
const
char
*>
(
arg
.
data
);
// Insert padding
std
::
size_t
padding
=
(
arg
.
align
-
(
kernargs
.
size
()
%
arg
.
align
))
%
arg
.
align
;
kernargs
.
insert
(
kernargs
.
end
(),
padding
,
0
);
kernargs
.
insert
(
kernargs
.
end
(),
p
,
p
+
n
);
}
return
kernargs
;
}
using
hip_module_ptr
=
RTC_MANAGE_PTR
(
hipModule_t
,
hipModuleUnload
);
struct
kernel_impl
{
hip_module_ptr
module
=
nullptr
;
hipFunction_t
fun
=
nullptr
;
};
hip_module_ptr
load_module
(
const
char
*
image
)
{
hipModule_t
raw_m
;
auto
status
=
hipModuleLoadData
(
&
raw_m
,
image
);
hip_module_ptr
m
{
raw_m
};
if
(
status
!=
hipSuccess
)
throw
std
::
runtime_error
(
"Failed to load module: "
+
hip_error
(
status
));
return
m
;
}
kernel
::
kernel
(
const
char
*
image
,
const
std
::
string
&
name
)
:
impl
(
std
::
make_shared
<
kernel_impl
>
())
{
impl
->
module
=
load_module
(
image
);
auto
status
=
hipModuleGetFunction
(
&
impl
->
fun
,
impl
->
module
.
get
(),
name
.
c_str
());
if
(
hipSuccess
!=
status
)
throw
std
::
runtime_error
(
"Failed to get function: "
+
name
+
": "
+
hip_error
(
status
));
}
void
launch_kernel
(
hipFunction_t
fun
,
hipStream_t
stream
,
std
::
size_t
global
,
std
::
size_t
local
,
void
*
kernargs
,
std
::
size_t
size
)
{
assert
(
global
>
0
);
assert
(
local
>
0
);
void
*
config
[]
=
{
HIP_LAUNCH_PARAM_BUFFER_POINTER
,
kernargs
,
HIP_LAUNCH_PARAM_BUFFER_SIZE
,
&
size
,
HIP_LAUNCH_PARAM_END
};
auto
status
=
hipExtModuleLaunchKernel
(
fun
,
global
,
1
,
1
,
local
,
1
,
1
,
0
,
stream
,
nullptr
,
reinterpret_cast
<
void
**>
(
&
config
),
nullptr
,
nullptr
);
if
(
status
!=
hipSuccess
)
throw
std
::
runtime_error
(
"Failed to launch kernel: "
+
hip_error
(
status
));
}
void
kernel
::
launch
(
hipStream_t
stream
,
std
::
size_t
global
,
std
::
size_t
local
,
std
::
vector
<
void
*>
args
)
const
{
assert
(
impl
!=
nullptr
);
void
*
kernargs
=
args
.
data
();
std
::
size_t
size
=
args
.
size
()
*
sizeof
(
void
*
);
launch_kernel
(
impl
->
fun
,
stream
,
global
,
local
,
kernargs
,
size
);
}
void
kernel
::
launch
(
hipStream_t
stream
,
std
::
size_t
global
,
std
::
size_t
local
,
const
std
::
vector
<
kernel_argument
>&
args
)
const
{
assert
(
impl
!=
nullptr
);
std
::
vector
<
char
>
kernargs
=
pack_args
(
args
);
std
::
size_t
size
=
kernargs
.
size
();
launch_kernel
(
impl
->
fun
,
stream
,
global
,
local
,
kernargs
.
data
(),
size
);
}
}
// namespace rtc
\ No newline at end of file
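
The padding arithmetic in pack_args rounds each argument up to its natural alignment, mirroring the kernarg segment layout. A worked example with a hypothetical argument list (not from the commit):

#include <rtc/kernel.hpp>
#include <cassert>

// For (int, double, void*) on a 64-bit host:
//   int    x: size 4, align 4 -> padding 0, occupies bytes [0, 4)
//   double y: size 8, align 8 -> padding (8 - 4 % 8) % 8 = 4, occupies [8, 16)
//   void*  p: size 8, align 8 -> padding 0, occupies [16, 24)
void pack_demo()
{
    int x     = 1;
    double y  = 2.0;
    void* p   = nullptr;
    auto blob = rtc::pack_args({x, y, p});
    assert(blob.size() == 24);
}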
codegen/test/rtc/src/tmp_dir.cpp (new file, mode 100644)

#include <rtc/tmp_dir.hpp>
#include <algorithm>
#include <chrono>  // for std::chrono::steady_clock in unique_string
#include <cstdlib> // for std::system in execute
#include <random>
#include <sstream> // for std::stringstream in unique_string
#include <thread>
#include <unistd.h>

namespace rtc {

std::string random_string(std::string::size_type length)
{
    static const std::string& chars = "0123456789"
                                      "abcdefghijklmnopqrstuvwxyz"
                                      "ABCDEFGHIJKLMNOPQRSTUVWXYZ";

    std::mt19937 rg{std::random_device{}()};
    std::uniform_int_distribution<std::string::size_type> pick(0, chars.length() - 1);

    std::string str(length, 0);
    std::generate(str.begin(), str.end(), [&] { return chars[pick(rg)]; });
    return str;
}

std::string unique_string(const std::string& prefix)
{
    auto pid = getpid();
    auto tid = std::this_thread::get_id();
    auto clk = std::chrono::steady_clock::now().time_since_epoch().count();
    std::stringstream ss;
    ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16);
    return ss.str();
}

tmp_dir::tmp_dir(const std::string& prefix)
    : path(std::filesystem::temp_directory_path() /
           unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix))
{
    std::filesystem::create_directories(this->path);
}

void tmp_dir::execute(const std::string& cmd) const
{
    std::string s = "cd " + path.string() + "; " + cmd;
    std::system(s.c_str());
}

tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); }

} // namespace rtc
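
A short usage sketch tying the pieces together (the prefix and command are illustrative):

#include <rtc/tmp_dir.hpp>

void scratch_demo()
{
    // creates e.g. /tmp/ck-rtc-demo-<pid>-<tid>-<clock>-<random>; the pid,
    // thread id, clock, and random suffix keep concurrent builds from colliding
    rtc::tmp_dir td{"demo"};
    td.execute("echo hello > out.txt"); // runs with td.path as the working directory
} // the directory and its contents are removed here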
docs/dockerhub.rst

@@ -36,7 +36,7 @@ What is inside the image?
 The docker images have everything you need for running CK including:
 
-* `ROCm <https://www.amd.com/en/graphics/servers-solutions-rocm>`_
+* `ROCm <https://rocm.docs.amd.com/en/latest/index.html>`_
 * `CMake <https://cmake.org/getting-started/>`_
 * `Compiler <https://github.com/ROCm/llvm-project>`_
 * `Composable Kernel library <https://github.com/ROCm/composable_kernel>`_
docs/sphinx/requirements.in

-rocm-docs-core==0.35.0
+rocm-docs-core==0.35.1
 sphinxcontrib-bibtex==2.6.2
docs/sphinx/requirements.txt

@@ -113,7 +113,7 @@ requests==2.31.0
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==0.35.0
+rocm-docs-core==0.35.1
     # via -r requirements.in
 six==1.16.0
     # via
example/01_gemm/gemm_xdl_fp8.cpp

@@ -20,14 +20,18 @@ using BElementOp = PassThrough;
 using CElementOp = PassThrough;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto LoopSched   = ck::make_default_loop_scheduler();
+static constexpr auto PipelineVer = ck::PipelineVersion::v1;
+using ComputeTypeA = ck::f8_t;
+using ComputeTypeB = ck::f8_t;
 
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
-// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
+        < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
 // clang-format on
 
 using ReferenceGemmInstance = ck::tensor_operation::host::
example/01_gemm/gemm_xdl_fp8_bf8.cpp

@@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t;
 
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB|
-// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
+// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
         < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
 // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp

@@ -498,22 +498,15 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         }
     };
 
-    static bool IsSupportedArgument(const Argument& arg)
+    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
     {
-        if(!ck::is_xdl_supported())
-        {
-            return false;
-        }
-
         // check vector load/store
-        {
             using Row = ck::tensor_layout::gemm::RowMajor;
             using Col = ck::tensor_layout::gemm::ColumnMajor;
 
             // check vector load of A
             if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
             {
-                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
@@ -521,7 +514,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
             {
                 // FIXME: not rigorous
-                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
@@ -530,11 +523,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             {
                 return false;
             }
 
             // check vector laod of B
             if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
             {
-                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
@@ -542,7 +534,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
             {
                 // FIXME: not rigorous
-                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
@@ -574,7 +566,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             // only support RowMajor for now
             if constexpr(is_same_v<ELayout, Row>)
             {
-                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+                if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
                 {
                     return false;
                 }
@@ -583,9 +575,18 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             {
                 return false;
             }
 
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+        return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
+               GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
@@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         return str.str();
     }
 
(the entire block below is new in this hunk)

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    struct Descriptor
    {
        static constexpr auto ds_tuple()
        {
            return transform_tuples(
                [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
                DsDesc{});
        }
        using AGridDesc_M_K =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
        using BGridDesc_N_K =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
        using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
        using EGridDesc_M_N =
            remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
        using AGridDesc_AK0_M_AK1 =
            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
                DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
        using BGridDesc_BK0_N_BK1 =
            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
                DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
            decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                ds_tuple()))>;
        using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
            decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
        using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
            DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;

        // tensor descriptors for problem definiton
        AGridDesc_M_K a_grid_desc_m_k;
        BGridDesc_N_K b_grid_desc_n_k;
        DsGridDesc_M_N ds_grid_desc_m_n;
        EGridDesc_M_N e_grid_desc_m_n;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map;

        // element-wise op
        AElementwiseOperation a_element_op;
        BElementwiseOperation b_element_op;
        CDEElementwiseOperation cde_element_op;

        // for checking vector load/store
        index_t MRaw;
        index_t NRaw;
        index_t KRaw;

        bool has_main_k_block_loop = true;

        constexpr Descriptor(ADesc a,
                             BDesc b,
                             DsDesc ds,
                             EDesc e,
                             AElementwiseOperation a_element_op_,
                             BElementwiseOperation b_element_op_,
                             CDEElementwiseOperation cde_element_op_)
            : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
              b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
              ds_grid_desc_m_n{transform_tuples(
                  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
                  ds)},
              e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
              a_grid_desc_ak0_m_ak1{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
              b_grid_desc_bk0_n_bk1{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      transform_tuples(
                          [&](auto d) constexpr {
                              return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
                          },
                          ds))},
              e_grid_desc_mblock_mperblock_nblock_nperblock{
                  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                      e_grid_desc_m_n)},
              block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
              a_element_op{a_element_op_},
              b_element_op{b_element_op_},
              cde_element_op{cde_element_op_},
              MRaw{e.GetLength(I0)},
              NRaw{e.GetLength(I1)},
              KRaw{a.GetLength(I1)}
        {
        }

        constexpr bool IsValid() const
        {
            return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                               b_grid_desc_n_k,
                                               ds_grid_desc_m_n,
                                               e_grid_desc_m_n,
                                               block_2_etile_map) and
                   IsSupported(MRaw, NRaw, KRaw);
        }

        constexpr index_t GetBlockSize() const { return BlockSize; }

        constexpr index_t GetGridSize() const
        {
            return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
        }
    };

    template <class ADesc, class BDesc, class DsDesc, class EDesc>
    static constexpr auto
    make_descriptor(ADesc a,
                    BDesc b,
                    DsDesc ds,
                    EDesc e,
                    AElementwiseOperation a_element_op     = AElementwiseOperation{},
                    BElementwiseOperation b_element_op     = BElementwiseOperation{},
                    CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
    {
        return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
            a, b, ds, e, a_element_op, b_element_op, cde_element_op);
    }

    template <class Desc, class DsPointer>
    __device__ static void Run(const Desc& desc,
                               const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               DsPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid)
    {
        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        assert(desc.IsValid());
        if(desc.has_main_k_block_loop)
        {
            GridwiseGemm::template Run<true>(p_a_grid,
                                             p_b_grid,
                                             p_ds_grid,
                                             p_e_grid,
                                             p_shared_block,
                                             desc.a_element_op,
                                             desc.b_element_op,
                                             desc.cde_element_op,
                                             desc.a_grid_desc_ak0_m_ak1,
                                             desc.b_grid_desc_bk0_n_bk1,
                                             desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                             desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                             desc.block_2_etile_map);
        }
        else
        {
            GridwiseGemm::template Run<false>(p_a_grid,
                                              p_b_grid,
                                              p_ds_grid,
                                              p_e_grid,
                                              p_shared_block,
                                              desc.a_element_op,
                                              desc.b_element_op,
                                              desc.cde_element_op,
                                              desc.a_grid_desc_ak0_m_ak1,
                                              desc.b_grid_desc_bk0_n_bk1,
                                              desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                              desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                              desc.block_2_etile_map);
        }
    }
 };
 
 } // namespace device
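
The point of the Descriptor/make_descriptor/Run additions is to let generated kernels build and validate every grid descriptor at compile time instead of going through the runtime Argument path. A rough sketch of the intended call shape (DeviceOp, the descriptor literals, and the pointer types here are placeholders, not API confirmed by this commit):

template <class DeviceOp,
          class ADesc, class BDesc, class DsDesc, class EDesc,
          class ADataType, class BDataType, class DsPointer, class EDataType>
__global__ void generated_gemm(const ADataType* p_a,
                               const BDataType* p_b,
                               DsPointer p_ds,
                               EDataType* p_e)
{
    // all descriptor plumbing is resolved during compilation
    constexpr auto desc = DeviceOp::make_descriptor(ADesc{}, BDesc{}, DsDesc{}, EDesc{});
    static_assert(desc.IsValid(), "unsupported problem shape");
    DeviceOp::Run(desc, p_a, p_b, p_ds, p_e);
}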
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -24,9 +24,9 @@ struct BlockToCTileMap_M00_N0_M01
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                   index_t M01 = 1)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                             index_t M01 = 1)
         : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
     {
@@ -51,7 +51,7 @@ struct BlockToCTileMap_M00_N0_M01
     }
 
     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                             const CTileDim& c_tile_dim) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                       const CTileDim& c_tile_dim) const
     {
         if constexpr(DeviceCTileIndexCheck)
@@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01
         return true;
     }
 
-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel
@@ -120,18 +120,19 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) =
-        default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) =
-        default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
 
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
         : M_(M), N_(N), M01_(M01)
     {
 #if 0
@@ -142,7 +143,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }
 
     template <typename CGridDesc_M_N>
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                        index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        const CGridDesc_M_N& c_grid_desc_m_n,
+        index_t M01 = 8)
         : BlockToCTileMap_M00_N0_M01Adapt(
               c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
@@ -164,7 +166,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }
 
     template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
     {
         return true;
     }
@@ -237,7 +239,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }
 
     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                             const CTileDim& /* c_tile_dim */) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                       const CTileDim& /* c_tile_dim */) const
     {
         return true; // always valid provided that user gets grid size from CalculateGridSize()
@@ -616,7 +618,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
         return true; // always valid provided that user gets grid size from CalculateGridSize()
     }
 
-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    {
+        return true;
+    }
 
     private:
     index_t M01_;
@@ -674,7 +679,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
         return true;
     }
 
-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel
@@ -786,7 +791,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
         return true;
     }
 
-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel
@@ -910,7 +915,7 @@ struct OffsettedBlockToCTileMap
     }
 
     template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
     }
@@ -967,7 +972,7 @@ struct BlockToCTileMap_3DGrid_KSplit
     }
 
     template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
     {
         return true;
     }
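
These constexpr qualifiers exist so the tile maps can live inside the new compile-time Descriptor in device_gemm_multiple_d_xdl_cshuffle.hpp. A toy reduction of the pattern (illustrative only, not from the library):

struct ToyMap
{
    int m01 = 1;
    constexpr ToyMap() = default;
    constexpr explicit ToyMap(int m01_) : m01(m01_) {}
    constexpr bool CheckValidity() const { return m01 > 0; }
};

// once construction and validation are constexpr, validity can become a
// compile-time guarantee instead of a runtime check
static_assert(ToyMap{8}.CheckValidity(), "map must be usable at compile time");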
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

@@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                       const BGridDesc_N_K& b_grid_desc_n_k,
                       const DsGridDesc_M_N& ds_grid_desc_m_n,
                       const EGridDesc_M_N& e_grid_desc_m_n,
-                      const Block2ETileMap& block_2_etile_map)
+                      const Block2ETileMap&)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         }
 
         // check block-to-E-tile
-        if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
-        {
-            return false;
-        }
+        // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
+        // {
+        //     return false;
+        // }
 
         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         // check tensor size: cannot be larger than 2GB each
include/ck/utility/type_convert.hpp

@@ -166,9 +166,6 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 {
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
-    float max_fp8 = 240.0f;
-    if(!std::isinf(x))
-        x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
 #if defined(__gfx94__)
     union
     {
@@ -178,6 +175,11 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
     } val;
     val.fval      = x;
     uint32_t ival = 0;
+    const float max_fp8 = 240.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
     ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
     val.i32val = ival;
     return val.i8val[0]; // little endian
@@ -225,6 +227,11 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
     } val;
     val.fval      = x;
     uint32_t ival = 0;
+    const float max_bf8 = 57344.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
     ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
     val.i32val = ival;
     return val.i8val[0]; // little endian
@@ -265,9 +272,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
 template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 {
-    float max_fp8 = 240.0f;
-    if(!std::isinf(x))
-        x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
 #if defined(__gfx94__)
     union
     {
@@ -277,6 +281,11 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
     } val;
     val.fval      = x;
     uint32_t ival = 0;
+    const float max_fp8 = 240.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
     ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
     val.i32val = ival;
     return val.i8val[0];
@@ -322,6 +331,11 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
     } val;
     val.fval      = x;
     uint32_t ival = 0;
+    const float max_bf8 = 57344.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
     ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
     val.i32val = ival;
     return val.i8val[0];
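
The new clipping leans on the hardware median instruction: the median of {x, max, -max} equals a clamp of x into [-max, max] for finite x, and the nan_mask test skips the clamp for Inf/NaN inputs so their encodings survive. A host-side emulation of the semantics being relied on (illustrative, not part of the commit):

#include <algorithm>
#include <cassert>

// Emulates __builtin_amdgcn_fmed3f(x, max, -max) for finite x:
// the median of three values with two symmetric bounds is a clamp.
float fmed3_clamp(float x, float max_val)
{
    return std::clamp(x, -max_val, max_val);
}

void clip_demo()
{
    assert(fmed3_clamp(1000.0f, 240.0f) == 240.0f);   // fp8 overflow saturates to +240
    assert(fmed3_clamp(-1000.0f, 240.0f) == -240.0f); // and to -240 on the negative side
    assert(fmed3_clamp(3.5f, 240.0f) == 3.5f);        // in-range values pass through
}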