Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
liupw
numba-DTK
Commits
3e5f428e
Commit
3e5f428e
authored
Apr 06, 2024
by
dugupeiwen
Browse files
Remove use of llvmlite.llvmpy for 0.58
parent
5be111ee
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
57 deletions
+103
-57
numba/roc/codegen.py
numba/roc/codegen.py
+44
-6
numba/roc/compiler.py
numba/roc/compiler.py
+4
-4
numba/roc/dispatch.py
numba/roc/dispatch.py
+4
-2
numba/roc/hsaimpl.py
numba/roc/hsaimpl.py
+18
-17
numba/roc/target.py
numba/roc/target.py
+33
-28
No files found.
numba/roc/codegen.py
View file @
3e5f428e
from
llvmlite
import
binding
as
ll
from
llvmlite
import
binding
as
ll
from
llvmlite.llvmpy
import
core
as
lc
# from llvmlite.llvmpy import core as lc
import
llvmlite.ir
as
llvmir
from
numba.core
import
utils
from
numba.core
import
utils
from
numba.core.codegen
import
BaseCPU
Codegen
,
CodeLibrary
from
numba.core.codegen
import
Codegen
,
CodeLibrary
,
CPUCodeLibrary
from
.hlc
import
DATALAYOUT
,
TRIPLE
,
hlc
from
.hlc
import
DATALAYOUT
,
TRIPLE
,
hlc
class
HSACodeLibrary
(
CPUCodeLibrary
):
class
HSACodeLibrary
(
CodeLibrary
):
def
_optimize_functions
(
self
,
ll_module
):
def
_optimize_functions
(
self
,
ll_module
):
pass
pass
...
@@ -25,17 +25,55 @@ class HSACodeLibrary(CodeLibrary):
...
@@ -25,17 +25,55 @@ class HSACodeLibrary(CodeLibrary):
return
str
(
out
.
hsail
)
return
str
(
out
.
hsail
)
class
JITHSACodegen
(
BaseCPUCodegen
):
# class JITHSACodegen(Codegen):
# _library_class = HSACodeLibrary
# def _init(self, llvm_module):
# assert list(llvm_module.global_variables) == [], "Module isn't empty"
# self._data_layout = DATALAYOUT[utils.MACHINE_BITS]
# self._target_data = ll.create_target_data(self._data_layout)
# def _create_empty_module(self, name):
# ir_module = llvmir.Module(name)
# ir_module.triple = TRIPLE
# return ir_module
# def _module_pass_manager(self):
# raise NotImplementedError
# def _function_pass_manager(self, llvm_module):
# raise NotImplementedError
# def _add_module(self, module):
# pass
class
JITHSACodegen
(
Codegen
):
_library_class
=
HSACodeLibrary
_library_class
=
HSACodeLibrary
def
__init__
(
self
,
module_name
):
# initialize_llvm()
ll
.
initialize
()
ll
.
initialize_native_target
()
ll
.
initialize_native_asmprinter
()
self
.
_data_layout
=
None
self
.
_llvm_module
=
ll
.
parse_assembly
(
str
(
self
.
_create_empty_module
(
module_name
)))
self
.
_llvm_module
.
name
=
"global_codegen_module"
# self._rtlinker = RuntimeLinker()
self
.
_init
(
self
.
_llvm_module
)
def
_init
(
self
,
llvm_module
):
def
_init
(
self
,
llvm_module
):
assert
list
(
llvm_module
.
global_variables
)
==
[],
"Module isn't empty"
assert
list
(
llvm_module
.
global_variables
)
==
[],
"Module isn't empty"
self
.
_data_layout
=
DATALAYOUT
[
utils
.
MACHINE_BITS
]
self
.
_data_layout
=
DATALAYOUT
[
utils
.
MACHINE_BITS
]
self
.
_target_data
=
ll
.
create_target_data
(
self
.
_data_layout
)
self
.
_target_data
=
ll
.
create_target_data
(
self
.
_data_layout
)
def
_create_empty_module
(
self
,
name
):
def
_create_empty_module
(
self
,
name
):
ir_module
=
l
c
.
Module
(
name
)
ir_module
=
l
lvmir
.
Module
(
name
)
ir_module
.
triple
=
TRIPLE
ir_module
.
triple
=
TRIPLE
if
self
.
_data_layout
:
ir_module
.
data_layout
=
self
.
_data_layout
return
ir_module
return
ir_module
def
_module_pass_manager
(
self
):
def
_module_pass_manager
(
self
):
...
...
numba/roc/compiler.py
View file @
3e5f428e
...
@@ -26,10 +26,10 @@ def compile_hsa(pyfunc, return_type, args, debug):
...
@@ -26,10 +26,10 @@ def compile_hsa(pyfunc, return_type, args, debug):
# TODO handle debug flag
# TODO handle debug flag
flags
=
compiler
.
Flags
()
flags
=
compiler
.
Flags
()
# Do not compile (generate native code), just lower (to LLVM)
# Do not compile (generate native code), just lower (to LLVM)
flags
.
set
(
'
no_compile
'
)
flags
.
no_compile
=
True
flags
.
set
(
'
no_cpython_wrapper
'
)
flags
.
no_cpython_wrapper
=
True
flags
.
set
(
'
no_cfunc_wrapper
'
)
flags
.
no_cfunc_wrapper
=
True
flags
.
unset
(
'nrt'
)
flags
.
nrt
=
False
# Run compilation pipeline
# Run compilation pipeline
cres
=
compiler
.
compile_extra
(
typingctx
=
typingctx
,
cres
=
compiler
.
compile_extra
(
typingctx
=
typingctx
,
targetctx
=
targetctx
,
targetctx
=
targetctx
,
...
...
numba/roc/dispatch.py
View file @
3e5f428e
import
numpy
as
np
import
numpy
as
np
from
numba.np.ufunc.deviceufunc
import
(
UFuncMechanism
,
GenerializedUFunc
,
# from numba.np.ufunc.deviceufunc import (UFuncMechanism, GenerializedUFunc,
# GUFuncCallSteps)
from
numba.np.ufunc.deviceufunc
import
(
UFuncMechanism
,
GeneralizedUFunc
,
GUFuncCallSteps
)
GUFuncCallSteps
)
from
numba.roc.hsadrv.driver
import
dgpu_present
from
numba.roc.hsadrv.driver
import
dgpu_present
import
numba.roc.hsadrv.devicearray
as
devicearray
import
numba.roc.hsadrv.devicearray
as
devicearray
...
@@ -119,7 +121,7 @@ class _HsaGUFuncCallSteps(GUFuncCallSteps):
...
@@ -119,7 +121,7 @@ class _HsaGUFuncCallSteps(GUFuncCallSteps):
kernel
.
configure
(
nelem
,
min
(
nelem
,
64
))(
*
args
)
kernel
.
configure
(
nelem
,
min
(
nelem
,
64
))(
*
args
)
class
HSAGenerializedUFunc
(
Gener
i
alizedUFunc
):
class
HSAGenerializedUFunc
(
GeneralizedUFunc
):
@
property
@
property
def
_call_steps
(
self
):
def
_call_steps
(
self
):
return
_HsaGUFuncCallSteps
return
_HsaGUFuncCallSteps
...
...
numba/roc/hsaimpl.py
View file @
3e5f428e
import
operator
import
operator
from
functools
import
reduce
from
functools
import
reduce
from
llvmlite.llvmpy.core
import
Type
#
from llvmlite.llvmpy.core import Type
import
llvmlite.llvmpy.core
as
lc
#
import llvmlite.llvmpy.core as lc
import
llvmlite.binding
as
ll
import
llvmlite.binding
as
ll
from
llvmlite
import
ir
from
llvmlite
import
ir
from
numba
import
roc
from
numba
import
roc
from
numba.core.imputils
import
Registry
from
numba.core.imputils
import
Registry
from
numba.core
import
types
,
cgutils
from
numba.core
import
types
,
cgutils
from
numba.core.itanium_mangler
import
mangle_c
,
mangle
,
mangle_type
# from numba.core.itanium_mangler import mangle_c, mangle, mangle_type
from
numba.core.itanium_mangler
import
mangle
,
mangle_type
from
numba.core.typing.npydecl
import
parse_dtype
from
numba.core.typing.npydecl
import
parse_dtype
from
numba.roc
import
target
from
numba.roc
import
target
from
numba.roc
import
stubs
from
numba.roc
import
stubs
...
@@ -19,13 +20,13 @@ from numba.roc import enums
...
@@ -19,13 +20,13 @@ from numba.roc import enums
registry
=
Registry
()
registry
=
Registry
()
lower
=
registry
.
lower
lower
=
registry
.
lower
_void_value
=
lc
.
Constant
.
null
(
lc
.
Type
.
pointer
(
lc
.
Type
.
int
(
8
))
)
_void_value
=
ir
.
Constant
(
ir
.
PointerType
(
ir
.
IntType
(
8
)),
None
)
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def
_declare_function
(
context
,
builder
,
name
,
sig
,
cargs
,
def
_declare_function
(
context
,
builder
,
name
,
sig
,
cargs
,
mangler
=
mangle
_c
):
mangler
=
mangle
):
"""Insert declaration for a opencl builtin function.
"""Insert declaration for a opencl builtin function.
Uses the Itanium mangler.
Uses the Itanium mangler.
...
@@ -50,11 +51,11 @@ def _declare_function(context, builder, name, sig, cargs,
...
@@ -50,11 +51,11 @@ def _declare_function(context, builder, name, sig, cargs,
"""
"""
mod
=
builder
.
module
mod
=
builder
.
module
if
sig
.
return_type
==
types
.
void
:
if
sig
.
return_type
==
types
.
void
:
llretty
=
lc
.
Type
.
void
()
llretty
=
ir
.
VoidType
()
else
:
else
:
llretty
=
context
.
get_value_type
(
sig
.
return_type
)
llretty
=
context
.
get_value_type
(
sig
.
return_type
)
llargs
=
[
context
.
get_value_type
(
t
)
for
t
in
sig
.
args
]
llargs
=
[
context
.
get_value_type
(
t
)
for
t
in
sig
.
args
]
fnty
=
Type
.
f
unction
(
llretty
,
llargs
)
fnty
=
ir
.
F
unction
Type
(
llretty
,
llargs
)
mangled
=
mangler
(
name
,
cargs
)
mangled
=
mangler
(
name
,
cargs
)
fn
=
mod
.
get_or_insert_function
(
fnty
,
mangled
)
fn
=
mod
.
get_or_insert_function
(
fnty
,
mangled
)
fn
.
calling_convention
=
target
.
CC_SPIR_FUNC
fn
.
calling_convention
=
target
.
CC_SPIR_FUNC
...
@@ -154,7 +155,7 @@ def mem_fence_impl(context, builder, sig, args):
...
@@ -154,7 +155,7 @@ def mem_fence_impl(context, builder, sig, args):
@
lower
(
stubs
.
wavebarrier
)
@
lower
(
stubs
.
wavebarrier
)
def
wavebarrier_impl
(
context
,
builder
,
sig
,
args
):
def
wavebarrier_impl
(
context
,
builder
,
sig
,
args
):
assert
not
args
assert
not
args
fnty
=
Type
.
f
unction
(
Type
.
void
(),
[])
fnty
=
ir
.
F
unctionType
(
ir
.
VoidType
(),
[])
fn
=
builder
.
module
.
declare_intrinsic
(
'llvm.amdgcn.wave.barrier'
,
fnty
=
fnty
)
fn
=
builder
.
module
.
declare_intrinsic
(
'llvm.amdgcn.wave.barrier'
,
fnty
=
fnty
)
builder
.
call
(
fn
,
[])
builder
.
call
(
fn
,
[])
return
_void_value
return
_void_value
...
@@ -166,12 +167,12 @@ def activelanepermute_wavewidth_impl(context, builder, sig, args):
...
@@ -166,12 +167,12 @@ def activelanepermute_wavewidth_impl(context, builder, sig, args):
assert
sig
.
args
[
0
]
==
sig
.
args
[
2
]
assert
sig
.
args
[
0
]
==
sig
.
args
[
2
]
elem_type
=
sig
.
args
[
0
]
elem_type
=
sig
.
args
[
0
]
bitwidth
=
elem_type
.
bitwidth
bitwidth
=
elem_type
.
bitwidth
intbitwidth
=
Type
.
int
(
bitwidth
)
intbitwidth
=
ir
.
Int
Type
(
bitwidth
)
i32
=
Type
.
int
(
32
)
i32
=
ir
.
Int
Type
(
32
)
i1
=
Type
.
int
(
1
)
i1
=
ir
.
Int
Type
(
1
)
name
=
"__hsail_activelanepermute_wavewidth_b{0}"
.
format
(
bitwidth
)
name
=
"__hsail_activelanepermute_wavewidth_b{0}"
.
format
(
bitwidth
)
fnty
=
Type
.
f
unction
(
intbitwidth
,
[
intbitwidth
,
i32
,
intbitwidth
,
i1
])
fnty
=
ir
.
F
unction
Type
(
intbitwidth
,
[
intbitwidth
,
i32
,
intbitwidth
,
i1
])
fn
=
builder
.
module
.
get_or_insert_function
(
fnty
,
name
=
name
)
fn
=
builder
.
module
.
get_or_insert_function
(
fnty
,
name
=
name
)
fn
.
calling_convention
=
target
.
CC_SPIR_FUNC
fn
.
calling_convention
=
target
.
CC_SPIR_FUNC
...
@@ -188,14 +189,14 @@ def _gen_ds_permute(intrinsic_name):
...
@@ -188,14 +189,14 @@ def _gen_ds_permute(intrinsic_name):
"""
"""
assert
sig
.
return_type
==
sig
.
args
[
1
]
assert
sig
.
return_type
==
sig
.
args
[
1
]
idx
,
src
=
args
idx
,
src
=
args
i32
=
Type
.
int
(
32
)
i32
=
ir
.
Int
Type
(
32
)
fnty
=
Type
.
f
unction
(
i32
,
[
i32
,
i32
])
fnty
=
ir
.
F
unction
Type
(
i32
,
[
i32
,
i32
])
fn
=
builder
.
module
.
declare_intrinsic
(
intrinsic_name
,
fnty
=
fnty
)
fn
=
builder
.
module
.
declare_intrinsic
(
intrinsic_name
,
fnty
=
fnty
)
# the args are byte addressable, VGPRs are 4 wide so mul idx by 4
# the args are byte addressable, VGPRs are 4 wide so mul idx by 4
# the idx might be an int64, this is ok to trunc to int32 as
# the idx might be an int64, this is ok to trunc to int32 as
# wavefront_size is never likely overflow an int32
# wavefront_size is never likely overflow an int32
idx
=
builder
.
trunc
(
idx
,
i32
)
idx
=
builder
.
trunc
(
idx
,
i32
)
four
=
lc
.
Constant
.
int
(
i32
,
4
)
four
=
ir
.
Constant
(
i32
,
4
)
idx
=
builder
.
mul
(
idx
,
four
)
idx
=
builder
.
mul
(
idx
,
four
)
# bit cast is so float32 works as packed i32, the return casts back
# bit cast is so float32 works as packed i32, the return casts back
result
=
builder
.
call
(
fn
,
(
idx
,
builder
.
bitcast
(
src
,
i32
)))
result
=
builder
.
call
(
fn
,
(
idx
,
builder
.
bitcast
(
src
,
i32
)))
...
@@ -258,7 +259,7 @@ def hsail_smem_alloc_array_tuple(context, builder, sig, args):
...
@@ -258,7 +259,7 @@ def hsail_smem_alloc_array_tuple(context, builder, sig, args):
def
_generic_array
(
context
,
builder
,
shape
,
dtype
,
symbol_name
,
addrspace
):
def
_generic_array
(
context
,
builder
,
shape
,
dtype
,
symbol_name
,
addrspace
):
elemcount
=
reduce
(
operator
.
mul
,
shape
,
1
)
elemcount
=
reduce
(
operator
.
mul
,
shape
,
1
)
lldtype
=
context
.
get_data_type
(
dtype
)
lldtype
=
context
.
get_data_type
(
dtype
)
laryty
=
Type
.
array
(
lldtype
,
elemcount
)
laryty
=
ir
.
ArrayType
(
lldtype
,
elemcount
)
if
addrspace
==
target
.
SPIR_LOCAL_ADDRSPACE
:
if
addrspace
==
target
.
SPIR_LOCAL_ADDRSPACE
:
lmod
=
builder
.
module
lmod
=
builder
.
module
...
@@ -269,7 +270,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
...
@@ -269,7 +270,7 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace):
if
elemcount
<=
0
:
if
elemcount
<=
0
:
raise
ValueError
(
"array length <= 0"
)
raise
ValueError
(
"array length <= 0"
)
else
:
else
:
gvmem
.
linkage
=
lc
.
LINKAGE_INTERNAL
gvmem
.
linkage
=
'internal'
if
dtype
not
in
types
.
number_domain
:
if
dtype
not
in
types
.
number_domain
:
raise
TypeError
(
"unsupported type: %s"
%
dtype
)
raise
TypeError
(
"unsupported type: %s"
%
dtype
)
...
...
numba/roc/target.py
View file @
3e5f428e
import
re
import
re
from
llvmlite.llvmpy
import
core
as
lc
# from llvmlite.llvmpy import core as lc
from
llvmlite
import
ir
as
llvmir
# from llvmlite import ir as llvmir
from
llvmlite
import
ir
from
llvmlite
import
binding
as
ll
from
llvmlite
import
binding
as
ll
from
numba.core
import
typing
,
types
,
utils
,
datamodel
,
cgutils
from
numba.core
import
typing
,
types
,
utils
,
datamodel
,
cgutils
from
numba.core.utils
import
cached_property
# from numba.core.utils import cached_property
from
functools
import
cached_property
from
numba.core.base
import
BaseContext
from
numba.core.base
import
BaseContext
from
numba.core.callconv
import
MinimalCallConv
from
numba.core.callconv
import
MinimalCallConv
from
numba.roc
import
codegen
from
numba.roc
import
codegen
...
@@ -65,6 +67,9 @@ class HSATargetContext(BaseContext):
...
@@ -65,6 +67,9 @@ class HSATargetContext(BaseContext):
implement_powi_as_math_call
=
True
implement_powi_as_math_call
=
True
generic_addrspace
=
SPIR_GENERIC_ADDRSPACE
generic_addrspace
=
SPIR_GENERIC_ADDRSPACE
def
__init__
(
self
,
typingctx
,
target
=
'ROCm'
):
super
().
__init__
(
typingctx
,
target
)
def
init
(
self
):
def
init
(
self
):
self
.
_internal_codegen
=
codegen
.
JITHSACodegen
(
"numba.hsa.jit"
)
self
.
_internal_codegen
=
codegen
.
JITHSACodegen
(
"numba.hsa.jit"
)
self
.
_target_data
=
\
self
.
_target_data
=
\
...
@@ -89,7 +94,7 @@ class HSATargetContext(BaseContext):
...
@@ -89,7 +94,7 @@ class HSATargetContext(BaseContext):
def
target_data
(
self
):
def
target_data
(
self
):
return
self
.
_target_data
return
self
.
_target_data
def
mangler
(
self
,
name
,
argtypes
):
def
mangler
(
self
,
name
,
argtypes
,
*
,
abi_tags
=
(),
uid
=
None
):
def
repl
(
m
):
def
repl
(
m
):
ch
=
m
.
group
(
0
)
ch
=
m
.
group
(
0
)
return
"_%X_"
%
ord
(
ch
)
return
"_%X_"
%
ord
(
ch
)
...
@@ -119,7 +124,7 @@ class HSATargetContext(BaseContext):
...
@@ -119,7 +124,7 @@ class HSATargetContext(BaseContext):
arginfo
=
self
.
get_arg_packer
(
argtypes
)
arginfo
=
self
.
get_arg_packer
(
argtypes
)
def
sub_gen_with_global
(
lty
):
def
sub_gen_with_global
(
lty
):
if
isinstance
(
lty
,
llvm
ir
.
PointerType
):
if
isinstance
(
lty
,
ir
.
PointerType
):
return
(
lty
.
pointee
.
as_pointer
(
SPIR_GLOBAL_ADDRSPACE
),
return
(
lty
.
pointee
.
as_pointer
(
SPIR_GLOBAL_ADDRSPACE
),
lty
.
addrspace
)
lty
.
addrspace
)
return
lty
,
None
return
lty
,
None
...
@@ -129,13 +134,13 @@ class HSATargetContext(BaseContext):
...
@@ -129,13 +134,13 @@ class HSATargetContext(BaseContext):
arginfo
.
argument_types
))
arginfo
.
argument_types
))
else
:
else
:
llargtys
=
changed
=
()
llargtys
=
changed
=
()
wrapperfnty
=
lc
.
Type
.
function
(
lc
.
Type
.
void
(),
llargtys
)
wrapperfnty
=
ir
.
FunctionType
(
ir
.
VoidType
(),
llargtys
)
wrapper_module
=
self
.
create_module
(
"hsa.kernel.wrapper"
)
wrapper_module
=
self
.
create_module
(
"hsa.kernel.wrapper"
)
wrappername
=
'hsaPy_{name}'
.
format
(
name
=
func
.
name
)
wrappername
=
'hsaPy_{name}'
.
format
(
name
=
func
.
name
)
argtys
=
list
(
arginfo
.
argument_types
)
argtys
=
list
(
arginfo
.
argument_types
)
fnty
=
lc
.
Type
.
function
(
lc
.
Type
.
int
(),
fnty
=
ir
.
FunctionType
(
ir
.
IntType
(),
[
self
.
call_conv
.
get_return_type
(
[
self
.
call_conv
.
get_return_type
(
types
.
pyobject
)]
+
argtys
)
types
.
pyobject
)]
+
argtys
)
...
@@ -144,7 +149,7 @@ class HSATargetContext(BaseContext):
...
@@ -144,7 +149,7 @@ class HSATargetContext(BaseContext):
wrapper
=
wrapper_module
.
add_function
(
wrapperfnty
,
name
=
wrappername
)
wrapper
=
wrapper_module
.
add_function
(
wrapperfnty
,
name
=
wrappername
)
builder
=
lc
.
Builder
(
wrapper
.
append_basic_block
(
''
))
builder
=
ir
.
IR
Builder
(
wrapper
.
append_basic_block
(
''
))
# Adjust address space of each kernel argument
# Adjust address space of each kernel argument
fixed_args
=
[]
fixed_args
=
[]
...
@@ -193,7 +198,7 @@ class HSATargetContext(BaseContext):
...
@@ -193,7 +198,7 @@ class HSATargetContext(BaseContext):
"""
"""
Handle addrspacecast
Handle addrspacecast
"""
"""
ptras
=
llvm
ir
.
PointerType
(
src
.
type
.
pointee
,
addrspace
=
addrspace
)
ptras
=
ir
.
PointerType
(
src
.
type
.
pointee
,
addrspace
=
addrspace
)
return
builder
.
addrspacecast
(
src
,
ptras
)
return
builder
.
addrspacecast
(
src
,
ptras
)
...
@@ -213,7 +218,7 @@ def set_hsa_kernel(fn):
...
@@ -213,7 +218,7 @@ def set_hsa_kernel(fn):
# Mark kernels
# Mark kernels
ocl_kernels
=
mod
.
get_or_insert_named_metadata
(
"opencl.kernels"
)
ocl_kernels
=
mod
.
get_or_insert_named_metadata
(
"opencl.kernels"
)
ocl_kernels
.
add
(
lc
.
M
eta
D
ata
.
get
(
mod
,
[
fn
,
ocl_kernels
.
add
(
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
fn
,
gen_arg_addrspace_md
(
fn
),
gen_arg_addrspace_md
(
fn
),
gen_arg_access_qual_md
(
fn
),
gen_arg_access_qual_md
(
fn
),
gen_arg_type
(
fn
),
gen_arg_type
(
fn
),
...
@@ -221,16 +226,16 @@ def set_hsa_kernel(fn):
...
@@ -221,16 +226,16 @@ def set_hsa_kernel(fn):
gen_arg_base_type
(
fn
)]))
gen_arg_base_type
(
fn
)]))
# SPIR version 2.0
# SPIR version 2.0
make_constant
=
lambda
x
:
lc
.
Constant
.
int
(
lc
.
Type
.
int
(),
x
)
make_constant
=
lambda
x
:
ir
.
Constant
(
ir
.
IntType
(),
x
)
spir_version_constant
=
[
make_constant
(
x
)
for
x
in
SPIR_VERSION
]
spir_version_constant
=
[
make_constant
(
x
)
for
x
in
SPIR_VERSION
]
spir_version
=
mod
.
get_or_insert_named_metadata
(
"opencl.spir.version"
)
spir_version
=
mod
.
get_or_insert_named_metadata
(
"opencl.spir.version"
)
if
not
spir_version
.
operands
:
if
not
spir_version
.
operands
:
spir_version
.
add
(
lc
.
M
eta
D
ata
.
get
(
mod
,
spir_version_constant
))
spir_version
.
add
(
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
spir_version_constant
))
ocl_version
=
mod
.
get_or_insert_named_metadata
(
"opencl.ocl.version"
)
ocl_version
=
mod
.
get_or_insert_named_metadata
(
"opencl.ocl.version"
)
if
not
ocl_version
.
operands
:
if
not
ocl_version
.
operands
:
ocl_version
.
add
(
lc
.
M
eta
D
ata
.
get
(
mod
,
spir_version_constant
))
ocl_version
.
add
(
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
spir_version_constant
))
## The following metadata does not seem to be necessary
## The following metadata does not seem to be necessary
# Other metadata
# Other metadata
...
@@ -259,9 +264,9 @@ def gen_arg_addrspace_md(fn):
...
@@ -259,9 +264,9 @@ def gen_arg_addrspace_md(fn):
else
:
else
:
codes
.
append
(
SPIR_PRIVATE_ADDRSPACE
)
codes
.
append
(
SPIR_PRIVATE_ADDRSPACE
)
consts
=
[
lc
.
Constant
.
int
(
lc
.
Type
.
int
(),
x
)
for
x
in
codes
]
consts
=
[
ir
.
Constant
(
ir
.
IntType
(),
x
)
for
x
in
codes
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_addr_space"
)
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_addr_space"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_access_qual_md
(
fn
):
def
gen_arg_access_qual_md
(
fn
):
...
@@ -269,9 +274,9 @@ def gen_arg_access_qual_md(fn):
...
@@ -269,9 +274,9 @@ def gen_arg_access_qual_md(fn):
Generate kernel_arg_access_qual metadata
Generate kernel_arg_access_qual metadata
"""
"""
mod
=
fn
.
module
mod
=
fn
.
module
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
"none"
)]
*
len
(
fn
.
args
)
consts
=
[
ir
.
MetaDataString
(
mod
,
"none"
)]
*
len
(
fn
.
args
)
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_access_qual"
)
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_access_qual"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_type
(
fn
):
def
gen_arg_type
(
fn
):
...
@@ -280,9 +285,9 @@ def gen_arg_type(fn):
...
@@ -280,9 +285,9 @@ def gen_arg_type(fn):
"""
"""
mod
=
fn
.
module
mod
=
fn
.
module
fnty
=
fn
.
type
.
pointee
fnty
=
fn
.
type
.
pointee
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
consts
=
[
ir
.
MetaDataString
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_type"
)
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_type"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_type_qual
(
fn
):
def
gen_arg_type_qual
(
fn
):
...
@@ -291,9 +296,9 @@ def gen_arg_type_qual(fn):
...
@@ -291,9 +296,9 @@ def gen_arg_type_qual(fn):
"""
"""
mod
=
fn
.
module
mod
=
fn
.
module
fnty
=
fn
.
type
.
pointee
fnty
=
fn
.
type
.
pointee
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
""
)
for
_
in
fnty
.
args
]
consts
=
[
ir
.
MetaDataString
(
mod
,
""
)
for
_
in
fnty
.
args
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_type_qual"
)
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_type_qual"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
def
gen_arg_base_type
(
fn
):
def
gen_arg_base_type
(
fn
):
...
@@ -302,9 +307,9 @@ def gen_arg_base_type(fn):
...
@@ -302,9 +307,9 @@ def gen_arg_base_type(fn):
"""
"""
mod
=
fn
.
module
mod
=
fn
.
module
fnty
=
fn
.
type
.
pointee
fnty
=
fn
.
type
.
pointee
consts
=
[
lc
.
MetaDataString
.
get
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
consts
=
[
ir
.
MetaDataString
(
mod
,
str
(
a
))
for
a
in
fnty
.
args
]
name
=
lc
.
MetaDataString
.
get
(
mod
,
"kernel_arg_base_type"
)
name
=
ir
.
MetaDataString
(
mod
,
"kernel_arg_base_type"
)
return
lc
.
M
eta
D
ata
.
get
(
mod
,
[
name
]
+
consts
)
return
ir
.
Module
.
add_m
eta
d
ata
(
mod
,
[
name
]
+
consts
)
class
HSACallConv
(
MinimalCallConv
):
class
HSACallConv
(
MinimalCallConv
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment