Commit 8b59f4fe authored by Catheriany's avatar Catheriany
Browse files

Merge remote-tracking branch 'origin/main' into issue/204

parents 16506fc0 df1c6b5d
......@@ -72,6 +72,7 @@ def test(
x_stride,
w_dtype=torch.float16,
dtype=torch.float16,
sync=None
):
print(
f"Testing RMS_Norm on {torch_device} with y_shape:{y_shape} x_shape:{x_shape} w_shape:{w_shape}"
......@@ -89,9 +90,11 @@ def test(
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
]
x_tensor, y_tensor, w_tensor = [to_tensor(tensor, lib) for tensor in [x, y, w]]
if sync is not None:
sync()
descriptor = infiniopRMSNormDescriptor_t()
check_error(
......
......@@ -117,6 +117,7 @@ def test(
y_strides=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float32,
sync=None
):
if inplace == Inplace.INPLACE_X:
y_strides = x_strides
......@@ -147,8 +148,8 @@ def test(
else:
y_tensor = to_tensor(y, lib)
if torch_device == "npu":
synchronize_device(torch_device)
if sync is not None:
sync()
check_error(
lib.infiniopCreateRoPEDescriptor(
......
......@@ -162,6 +162,9 @@ target("infinirt")
if has_config("nv-gpu") then
add_deps("infinirt-cuda")
end
if has_config("cambricon-mlu") then
add_deps("infinirt-cambricon")
end
if has_config("ascend-npu") then
add_deps("infinirt-ascend")
end
......
......@@ -50,9 +50,8 @@ target("infiniop-ascend")
add_files("$(projectdir)/src/infiniop/devices/ascend/*.cc", "$(projectdir)/src/infiniop/ops/*/ascend/*.cc")
-- Add operator
-- TODO: add it back after ascend-kernels is fixed
-- add_rules("ascend-kernels")
-- add_links(builddir.."/libascend_kernels.a")
add_rules("ascend-kernels")
add_links(builddir.."/libascend_kernels.a")
target_end()
target("infinirt-ascend")
......
......@@ -50,3 +50,13 @@ target("infiniop-cambricon")
add_files(mlu_files, {rule = "mlu"})
end
target_end()
target("infinirt-cambricon")
set_kind("static")
add_deps("infini-utils")
set_languages("cxx17")
on_install(function (target) end)
-- Add include dirs
add_files("../src/infinirt/bang/*.cc")
add_cxflags("-lstdc++ -Wall -Werror -fPIC")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment