Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liupw
numba-DTK
Commits
3e5f428e
Commit
3e5f428e
authored
Apr 06, 2024
by
dugupeiwen
Browse files
Remove use of llvmlite.llvmpy for 0.58
parent
5be111ee
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
57 deletions
+103
-57
numba/roc/codegen.py
numba/roc/codegen.py
+44
-6
numba/roc/compiler.py
numba/roc/compiler.py
+4
-4
numba/roc/dispatch.py
numba/roc/dispatch.py
+4
-2
numba/roc/hsaimpl.py
numba/roc/hsaimpl.py
+18
-17
numba/roc/target.py
numba/roc/target.py
+33
-28
No files found.
numba/roc/codegen.py
View file @
3e5f428e
from
llvmlite
import
binding
as
ll
from
llvmlite.llvmpy
import
core
as
lc
# from llvmlite.llvmpy import core as lc
import
llvmlite.ir
as
llvmir
from
numba.core
import
utils
from
numba.core.codegen
import
BaseCPU
Codegen
,
CodeLibrary
from
numba.core.codegen
import
Codegen
,
CodeLibrary
,
CPUCodeLibrary
from
.hlc
import
DATALAYOUT
,
TRIPLE
,
hlc
class
HSACodeLibrary
(
CodeLibrary
):
class
HSACodeLibrary
(
CPUCodeLibrary
):
def
_optimize_functions
(
self
,
ll_module
):
pass
...
...
@@ -25,17 +25,55 @@ class HSACodeLibrary(CodeLibrary):
return
str
(
out
.
hsail
)
class
JITHSACodegen
(
BaseCPUCodegen
):
# class JITHSACodegen(Codegen):
# _library_class = HSACodeLibrary
# def _init(self, llvm_module):
# assert list(llvm_module.global_variables) == [], "Module isn't empty"
# self._data_layout = DATALAYOUT[utils.MACHINE_BITS]
# self._target_data = ll.create_target_data(self._data_layout)
# def _create_empty_module(self, name):
# ir_module = llvmir.Module(name)
# ir_module.triple = TRIPLE
# return ir_module
# def _module_pass_manager(self):
# raise NotImplementedError
# def _function_pass_manager(self, llvm_module):
# raise NotImplementedError
# def _add_module(self, module):
# pass
class
JITHSACodegen
(
Codegen
):
_library_class
=
HSACodeLibrary
def
__init__
(
self
,
module_name
):
# initialize_llvm()
ll
.
initialize
()
ll
.
initialize_native_target
()
ll
.
initialize_native_asmprinter
()
self
.
_data_layout
=
None
self
.
_llvm_module
=
ll
.
parse_assembly
(
str
(
self
.
_create_empty_module
(
module_name
)))
self
.
_llvm_module
.
name
=
"global_codegen_module"
# self._rtlinker = RuntimeLinker()
self
.
_init
(
self
.
_llvm_module
)
def
_init
(
self
,
llvm_module
):
assert
list
(
llvm_module
.
global_variables
)
==
[],
"Module isn't empty"
self
.
_data_layout
=
DATALAYOUT
[
utils
.
MACHINE_BITS
]
self
.
_target_data
=
ll
.
create_target_data
(
self
.
_data_layout
)
def
_create_empty_module
(
self
,
name
):
ir_module
=
l
c
.
Module
(
name
)
ir_module
=
l
lvmir
.
Module
(
name
)
ir_module
.
triple
=
TRIPLE
if
self
.
_data_layout
:
ir_module
.
data_layout
=
self
.
_data_layout
return
ir_module
def
_module_pass_manager
(
self
):
...
...
numba/roc/compiler.py
View file @
3e5f428e
...
...
@@ -26,10 +26,10 @@ def compile_hsa(pyfunc, return_type, args, debug):
# TODO handle debug flag
flags
=
compiler
.
Flags
()
# Do not compile (generate native code), just lower (to LLVM)
flags
.
set
(
'
no_compile
'
)
flags
.
set
(
'
no_cpython_wrapper
'
)
flags
.
set
(
'
no_cfunc_wrapper
'
)
flags
.
unset
(
'nrt'
)
flags
.
no_compile
=
True
flags
.
no_cpython_wrapper
=
True
flags
.
no_cfunc_wrapper
=
True
flags
.
nrt
=
False
# Run compilation pipeline
cres
=
compiler
.
compile_extra
(
typingctx
=
typingctx
,
targetctx
=
targetctx
,
...
...
numba/roc/dispatch.py
View file @
3e5f428e
import
numpy
as
np
from
numba.np.ufunc.deviceufunc
import
(
UFuncMechanism
,
GenerializedUFunc
,
# from numba.np.ufunc.deviceufunc import (UFuncMechanism, GenerializedUFunc,
# GUFuncCallSteps)
from
numba.np.ufunc.deviceufunc
import
(
UFuncMechanism
,
GeneralizedUFunc
,
GUFuncCallSteps
)
from
numba.roc.hsadrv.driver
import
dgpu_present
import
numba.roc.hsadrv.devicearray
as
devicearray
...
...
@@ -119,7 +121,7 @@ class _HsaGUFuncCallSteps(GUFuncCallSteps):
kernel
.
configure
(
nelem
,
min
(
nelem
,
64
))(
*
args
)
class
HSAGenerializedUFunc
(
Gener
i
alizedUFunc
):
class
HSAGenerializedUFunc
(
GeneralizedUFunc
):
@
property
def
_call_steps
(
self
):
return
_HsaGUFuncCallSteps
...
...
numba/roc/hsaimpl.py
View file @
3e5f428e
import
operator
from
functools
import
reduce
from
llvmlite.llvmpy.core
import
Type
import
llvmlite.llvmpy.core
as
lc
#
from llvmlite.llvmpy.core import Type
#
import llvmlite.llvmpy.core as lc
import
llvmlite.binding
as
ll
from
llvmlite
import
ir
from
numba
import
roc
from
numba.core.imputils
import
Registry
from
numba.core
import
types
,
cgutils
from
numba.core.itanium_mangler
import
mangle_c
,
mangle
,
mangle_type
# from numba.core.itanium_mangler import mangle_c, mangle, mangle_type
from
numba.core.itanium_mangler
import
mangle
,
mangle_type
from
numba.core.typing.npydecl
import
parse_dtype
from
numba.roc
import
target
from
numba.roc
import
stubs
...
...
@@ -19,13 +20,13 @@ from numba.roc import enums
registry
=
Registry
()
lower
=
registry
.
lower
_void_value
=
lc
.
Constant
.
null
(
lc
.
Type
.
pointer
(
lc
.
Type
.
int
(
8
))
)
_void_value
=
ir
.
Constant
(
ir
.
PointerType
(
ir
.
IntType
(
8
)),
None
)
# -----------------------------------------------------------------------------
def
_declare_function
(
context
,
builder
,
name
,
sig
,
cargs
,
mangler
=
mangle
_c
):
mangler
=
mangle
):
"""Insert declaration for a opencl builtin function.
Uses the Itanium mangler.
...
...
@@ -50,11 +51,11 @@ def _declare_function(context, builder, name, sig, cargs,
"""
mod
=
builder
.
module
if
sig
.
return_type
==
types
.
void
:
llretty
=
lc
.
Type
.
void
()
llretty
=
ir
.
VoidType
()
else
:
llretty
=
context
.
get_value_type
(
sig
.
return_type
)
llargs
=
[
context
.
get_value_type
(
t
)
for
t
in
sig
.
args
]
fnty
=
Type
.
f
unction
(
llretty
,
llargs
)
fnty
=
ir
.
F
unction
Type
(
llretty
,
llargs
)
mangled
=
mangler
(
name
,
cargs
)
fn
=
mod
.
get_or_insert_function
(
fnty
,
mangled
)
fn
.
calling_convention
=
target
.
CC_SPIR_FUNC
...
...
@@ -154,7 +155,7 @@ def mem_fence_impl(context, builder, sig, args):
@
lower
(
stubs
.
wavebarrier
)
def
wavebarrier_impl
(
context
,
builder
,
sig
,
args
):
assert
not
args
fnty
=
Type
.
f
unction
(
Type
.
void
(),
[])
fnty
=
ir
.
F
unctionType
(
ir
.
VoidType
(),
[])
fn
=
builder
.
module
.
declare_intrinsic
(
'llvm.amdgcn.wave.barrier'
,
fnty
=
fnty
)
builder
.
call
(
fn
,
[])
return
_void_value
...
...
@@ -166,12 +167,12 @@ def activelanepermute_wavewidth_impl(context, builder, sig, args):
assert
sig
.
args
[
0
]
==
sig
.
args
[
2
]
elem_type
=
sig
.
args
[
0
]
bitwidth
=
elem_type
.
bitwidth
intbitwidth
=
Type
.
int
(
bitwidth
)
i32
=
Type
.
int
(
32
)
i1
=
Type
.
int
(
1
)
intbitwidth
=
ir
.
Int
Type
(
bitwidth
)
i32
=
ir
.
Int
Type
(
32
)
i1
=
ir
.
Int
Type
(
1
)
name
=
"__hsail_activelanepermute_wavewidth_b{0}"
.
format
(
bitwidth
)
fnty
=
Type
.
f
unction
(
intbitwidth
,
[
intbitwidth
,
i32
,
intbitwidth
,
i1
])
fnty
=
ir
.
F
unction
Type
(
intbitwidth
,
[
intbitwidth
,
i32
,
intbitwidth
,
i1
])
fn
=
builder
.
module
.
get_or_insert_function
(
fnty
,
name
=
name
)
fn
.
calling_convention
=
target
.
CC_SPIR_FUNC
...
...
@@ -188,14 +189,14 @@ def _gen_ds_permute(intrinsic_name):
"""
assert
sig
.
return_type
==
sig
.
args
[
1
]
idx
,
src
=
args
i32
=
Type
.
int
(
32
)
fnty
=
Type
.
f
unction
(
i32
,
[
i32
,
i32
])
i32
=
ir
.
Int
Type
(
32
)
fnty
=
ir
.
F
unction
Type
(
i32
,
[
i32
,
i32
])
fn
=
builder
.
module
.
declare_intrinsic
(
intrinsic_name
,
fnty
=
fnty
)
# the args are byte addressable, VGPRs are 4 wide so mul idx by 4
# the idx might be an int64, this is ok to trunc to int32 as
# wavefront_size is never likely overflow an int32
idx
=
builder
.
trunc
(
idx
,
i32
)
four
=
lc
.
Constant
.
int
(
i32
,
4
)
four
=
ir
.
Constant
(
i32
,
4
)
idx
=
builder
.
mul
(
idx
,
four
)
# bit cast is so float32 works as packed i32, the return casts back
result
=
builder
.
call
(
fn
,
(
idx
,
builder
.
bitcast
(
src
,
i32
)))
...
...
@@ -258,7 +259,7 @@ def hsail_smem_alloc_array_tuple(context, builder, sig, args):
def
_generic_array
(
context
,
builder
,
shape
,
dtype
,
symbol_name
,
addrspace
):
elemcount
=
reduce
(
operator
.
mul
,
shape
,
1
)
lldtype
=
context
.
get_data_type
(
dtype
)
laryty
=
Type
.
array
(
lldtype
,
elemcount
)
laryty
=
ir
.
ArrayType
(
lldtype
,
elemcount
)
if
addrspace
==
target
.
SPIR_LOCAL_ADDRSPACE
:
lmod
=
builder
.
module
...
...
@@ -269,7 +270,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
if
elemcount
<=
0
:
raise
ValueError
(
"array length <= 0"
)
else
:
gvmem
.
linkage
=
lc
.
LINKAGE_INTERNAL
gvmem
.
linkage
=
'internal'
if
dtype
not
in
types
.
number_domain
:
raise
TypeError
(
"unsupported type: %s"
%
dtype
)
...
...
numba/roc/target.py
View file @
3e5f428e
import
re
from
llvmlite.llvmpy
import
core
as
lc
from
llvmlite
import
ir
as
llvmir
# from llvmlite.llvmpy import core as lc
# from llvmlite import ir as llvmir
from
llvmlite
import
ir
from
llvmlite
import
binding
as
ll
from
numba.core
import
typing
,
types
,
utils
,
datamodel
,
cgutils
from
numba.core.utils
import
cached_property
# from numba.core.utils import cached_property
from
functools
import
cached_property
from
numba.core.base
import
BaseContext
from
numba.core.callconv
import
MinimalCallConv
from
numba.roc
import
codegen
...
...
@@ -65,6 +67,9 @@ class HSATargetContext(BaseContext):
implement_powi_as_math_call
=
True
generic_addrspace
=
SPIR_GENERIC_ADDRSPACE
def
__init__
(
self
,
typingctx
,
target
=
'ROCm'
):
super
().
__init__
(
typingctx
,
target
)
def
init
(
self
):
self
.
_internal_codegen
=
codegen
.
JITHSACodegen
(
"numba.hsa.jit"
)
self
.
_target_data
=
\
...
...
@@ -89,7 +94,7 @@ class HSATargetContext(BaseContext):
def
target_data
(
self
):
return
self
.
_target_data
def
mangler
(
self
,
name
,
argtypes
):
def
mangler
(
self
,
name
,
argtypes
,
*
,
abi_tags
=
(),
uid
=
None
):
def
repl
(
m
):
ch
=
m
.
group
(
0
)
return
"_%X_"
%
ord
(
ch
)
...
...
@@ -119,7 +124,7 @@ class HSATargetContext(BaseContext):
arginfo
=
self
.
get_arg_packer
(
argtypes
)
def
sub_gen_with_global
(
lty
):
if
isinstance
(
lty
,
llvm
ir
.
PointerType
):
if
isinstance
(
lty
,
ir
.
PointerType
):
return
(
lty
.
pointee
.
as_pointer
(
SPIR_GLOBAL_ADDRSPACE
),
lty
.
addrspace
)
return
lty
,
None
...
...
@@ -129,13 +134,13 @@ class HSATargetContext(BaseContext):
arginfo
.
argument_types
))
else
:
llargtys
=
changed
=
()
wrapperfnty
=
lc
.
Type
.
function
(
lc
.
Type
.
void
(),
llargtys
)
wrapperfnty
=
ir
.
FunctionType
(
ir
.
VoidType
(),
llargtys
)
wrapper_module
=
self
.
create_module
(
"hsa.kernel.wrapper"
)
wrappername
=
'hsaPy_{name}'
.
format
(
name
=
func
.
name
)
argtys
=
list
(
arginfo
.
argument_types
)
fnty
=
lc
.
Type
.
function
(
lc
.
Type
.
int
(),
fnty
=
ir
.
FunctionType
(
ir
.
IntType
(),
[
self
.
call_conv
.
get_return_type
(
types
.
pyobject
)]
+
argtys
)
...
...
@@ -144,7 +149,7 @@ class HSATargetContext(BaseContext):
wrapper
=
wrapper_module
.
add_function
(
wrapperfnty
,
name
=
wrappername
)
builder
=
lc
.
Builder
(
wrapper
.
append_basic_block
(
''
))
builder
=
ir
.
IR
Builder
(
wrapper
.
append_basic_block
(
''
))
# Adjust address space of each kernel argument
fixed_args
=
[]
...
...
@@ -193,7 +198,7 @@ class HSATargetContext(BaseContext):
"""
Handle addrspacecast
"""
ptras
=
llvm
ir
.
PointerType
(
src
.
type
.
pointee
,
addrspace
=
addrspace
)
ptras
=
ir
.
PointerType
(
src
.
type
.
pointee
,
addrspace
=
addrspace
)
return
builder
.
addrspacecast
(
src
,
ptras
)
...
...
@@ -213,7 +218,7 @@ def set_hsa_kernel(fn):
# Mark kernels
ocl_kernels
=
mod
.
get_or_insert_named_metadata
(
"opencl.kernels"
)
ocl_kernels
.
add
(
lc
.
M
eta
D
ata
.
get
(
mod
,
[
fn
,
ocl_kernels
.
add
(
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
fn
,
gen_arg_addrspace_md
(
fn
),
gen_arg_access_qual_md
(
fn
),
gen_arg_type
(
fn
),
...
...
@@ -221,16 +226,16 @@ def set_hsa_kernel(fn):
gen_arg_base_type
(
fn
)]))
# SPIR version 2.0
make_constant
=
lambda
x
:
lc
.
Constant
.
int
(
lc
.
Type
.
int
(),
x
)
make_constant
=
lambda
x
:
ir
.
Constant
(
ir
.
IntType
(),
x
)
spir_version_constant
=
[
make_constant
(
x
)
for
x
in
SPIR_VERSION
]
spir_version
=
mod
.
get_or_insert_named_metadata
(
"opencl.spir.version"
)
if
not
spir_version
.
operands
:
spir_version
.
add
(
lc
.
M
eta
D
ata
.
get
(
mod
,
spir_version_constant
))
spir_version
.
add
(
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
spir_version_constant
))
ocl_version
=
mod
.
get_or_insert_named_metadata
(
"opencl.ocl.version"
)
if
not
ocl_version
.
operands
:
ocl_version
.
add
(
lc
.
M
eta
D
ata
.
get
(
mod
,
spir_version_constant
))
ocl_version
.
add
(
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
spir_version_constant
))
## The following metadata does not seem to be necessary
# Other metadata
...
...
@@ -259,9 +264,9 @@ def gen_arg_addrspace_md(fn):
else
:
codes
.
append
(
SPIR_PRIVATE_ADDRSPACE
)
consts
=
[
lc
.
Constant
.
int
(
lc
.
Type
.
int
(),
x
)
for
x
in
codes
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_addr_space"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
consts
=
[
ir
.
Constant
(
ir
.
IntType
(),
x
)
for
x
in
codes
]
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_addr_space"
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_access_qual_md
(
fn
):
...
...
@@ -269,9 +274,9 @@ def gen_arg_access_qual_md(fn):
Generate kernel_arg_access_qual metadata
"""
mod
=
fn
.
module
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
"none"
)]
*
len
(
fn
.
args
)
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_access_qual"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
consts
=
[
ir
.
MetaDataString
(
mod
,
"none"
)]
*
len
(
fn
.
args
)
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_access_qual"
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_type
(
fn
):
...
...
@@ -280,9 +285,9 @@ def gen_arg_type(fn):
"""
mod
=
fn
.
module
fnty
=
fn
.
type
.
pointee
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_type"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
consts
=
[
ir
.
MetaDataString
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_type"
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_type_qual
(
fn
):
...
...
@@ -291,9 +296,9 @@ def gen_arg_type_qual(fn):
"""
mod
=
fn
.
module
fnty
=
fn
.
type
.
pointee
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
""
)
for
_
in
fnty
.
args
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_type_qual"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
consts
=
[
ir
.
MetaDataString
(
mod
,
""
)
for
_
in
fnty
.
args
]
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_type_qual"
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_base_type
(
fn
):
...
...
@@ -302,9 +307,9 @@ def gen_arg_base_type(fn):
"""
mod
=
fn
.
module
fnty
=
fn
.
type
.
pointee
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_base_type"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
consts
=
[
ir
.
MetaDataString
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_base_type"
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
class
HSACallConv
(
MinimalCallConv
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment