Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f9761a29
Commit
f9761a29
authored
Jan 19, 2026
by
wooway777
Browse files
issue/900 - maintains classic embedding for devices yet to be worked on
parent
eb34d4d6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
79 additions
and
14 deletions
+79
-14
src/infinicore/nn/embedding.cc
src/infinicore/nn/embedding.cc
+78
-11
src/infinicore/ops/embedding/embedding.cc
src/infinicore/ops/embedding/embedding.cc
+1
-3
No files found.
src/infinicore/nn/embedding.cc
View file @
f9761a29
...
...
@@ -43,20 +43,87 @@ Embedding::Embedding(size_t num_embeddings,
}
Tensor Embedding::forward(const Tensor &indices) const {
    // Fast path: these devices have a native op::embedding kernel that
    // supports device-side indices and a batch dimension.
    // TODO: Implement on-device embedding for all devices, then remove the
    // condition and the classic approach below.
    const auto device_type = device_.getType();
    if (device_type == Device::Type::NVIDIA
        || device_type == Device::Type::ILUVATAR
        || device_type == Device::Type::METAX
        || device_type == Device::Type::MOORE) {
        // Use op::embedding which supports device-side input and batch dimension.
        return op::embedding(indices->contiguous()->to(device_), weight_);
    }

    // Classic fallback: gather one weight row per index, driven from the host.
    // Output shape: indices_shape + [embedding_dim].
    const auto indices_shape = indices->shape();
    std::vector<size_t> output_shape = indices_shape;
    output_shape.push_back(embedding_dim_);

    // Create output tensor on the same device as weight.
    auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());

    // Indices must be readable on the host for the row-copy loop below.
    const auto cpu_device = Device(Device::Type::CPU, 0);
    const auto indices_cpu = indices->to(cpu_device)->contiguous();

    // Total number of row lookups = product of all index dimensions
    // (an empty shape degenerates to a single scalar lookup).
    size_t num_lookups = 1;
    for (const auto dim : indices_shape) {
        num_lookups *= dim;
    }

    const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype());

    // Source and destination base pointers (byte-addressed).
    auto *weight_base = weight_->data();
    auto *out_base = out->data();

    // Resolve the index dtype ONCE; the original re-queried it inside the
    // per-element lambda, repeating the dispatch chain for every lookup.
    const auto index_dtype = indices_cpu->dtype();

    // Read the i-th index as int64_t, validating dtype and value range.
    auto read_index = [&](size_t i) -> int64_t {
        if (index_dtype == DataType::I32) {
            const auto *data = reinterpret_cast<const int32_t *>(indices_cpu->data());
            return static_cast<int64_t>(data[i]);
        } else if (index_dtype == DataType::I64) {
            const auto *data = reinterpret_cast<const int64_t *>(indices_cpu->data());
            return data[i];
        } else if (index_dtype == DataType::U32) {
            const auto *data = reinterpret_cast<const uint32_t *>(indices_cpu->data());
            return static_cast<int64_t>(data[i]);
        } else if (index_dtype == DataType::U64) {
            const auto *data = reinterpret_cast<const uint64_t *>(indices_cpu->data());
            const uint64_t val = data[i];
            // Check if value can fit in int64_t.
            if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
                throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val));
            }
            return static_cast<int64_t>(val);
        } else {
            throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast<int>(index_dtype)));
        }
    };

    // Read and bounds-check the i-th index. Shared by both copy paths so the
    // validation (and its error message) cannot drift between them.
    auto checked_index = [&](size_t i) -> int64_t {
        const int64_t idx = read_index(i);
        if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
            throw std::out_of_range("Index out of range: " + std::to_string(idx)
                                    + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
        }
        return idx;
    };

    if (weight_->device().getType() == Device::Type::CPU) {
        // CPU path: memcpy row by row.
        for (size_t i = 0; i < num_lookups; ++i) {
            const int64_t idx = checked_index(i);
            std::memcpy(out_base + i * row_bytes,
                        weight_base + idx * row_bytes,
                        row_bytes);
        }
    } else {
        // Device path: use stream-ordered D2D copies.
        for (size_t i = 0; i < num_lookups; ++i) {
            const int64_t idx = checked_index(i);
            context::memcpyD2D(out_base + i * row_bytes,
                               weight_base + idx * row_bytes,
                               row_bytes);
        }
    }
    return out;
}
std
::
string
Embedding
::
extra_repr
()
const
{
...
...
src/infinicore/ops/embedding/embedding.cc
View file @
f9761a29
#include "infinicore/ops/embedding.hpp"
#include "../../utils.hpp"
#include "infinicore/context/context.hpp"
#include <cstring>
#include <stdexcept>
namespace
infinicore
::
op
{
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL
(
Embedding
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment