Commit f9761a29 authored by wooway777's avatar wooway777
Browse files

issue/900 - maintains classic embedding for devices yet to be worked on

parent eb34d4d6
......@@ -43,20 +43,87 @@ Embedding::Embedding(size_t num_embeddings,
}
Tensor Embedding::forward(const Tensor &indices) const {
// Ensure indices are on the same device as weight
// This avoids synchronous memcpy in ops layer which would hurt performance
Tensor indices_on_device = indices;
if (indices->device() != device_) {
indices_on_device = indices->to(device_);
// TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
auto device_type = device_.getType();
if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE) {
// Use op::embedding which supports device-side input and batch dimension
return op::embedding(indices->contiguous()->to(device_), weight_);
}
// Ensure indices are contiguous for efficient access
// op::embedding now supports device-side input for graph recording
Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous();
// Get the shape of indices
auto indices_shape = indices->shape();
// Use op::embedding which now supports device-side input and batch dimension
// This enables full graph recording support without synchronization
return op::embedding(indices_contiguous, weight_);
// Output shape: indices_shape + [embedding_dim]
std::vector<size_t> output_shape = indices_shape;
output_shape.push_back(embedding_dim_);
// Create output tensor on the same device as weight
auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());
// Flatten indices for sequential row copies
auto cpu_device = Device(Device::Type::CPU, 0);
auto indices_cpu = indices->to(cpu_device)->contiguous();
// Calculate total number of lookups
size_t num_lookups = 1;
for (auto dim : indices_shape) {
num_lookups *= dim;
}
const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype());
// Source and destination base pointers
auto *weight_base = weight_->data();
auto *out_base = out->data();
// Helper lambda to read index based on dtype with bounds checking
auto read_index = [&](size_t i) -> int64_t {
auto dtype = indices_cpu->dtype();
if (dtype == DataType::I32) {
const auto *data = reinterpret_cast<const int32_t *>(indices_cpu->data());
return static_cast<int64_t>(data[i]);
} else if (dtype == DataType::I64) {
const auto *data = reinterpret_cast<const int64_t *>(indices_cpu->data());
return data[i];
} else if (dtype == DataType::U32) {
const auto *data = reinterpret_cast<const uint32_t *>(indices_cpu->data());
return static_cast<int64_t>(data[i]);
} else if (dtype == DataType::U64) {
const auto *data = reinterpret_cast<const uint64_t *>(indices_cpu->data());
uint64_t val = data[i];
// Check if value can fit in int64_t
if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val));
}
return static_cast<int64_t>(val);
} else {
throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast<int>(dtype)));
}
};
if (weight_->device().getType() == Device::Type::CPU) {
// CPU path: memcpy row by row
for (size_t i = 0; i < num_lookups; ++i) {
int64_t idx = read_index(i);
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
throw std::out_of_range(
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
}
std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
}
} else {
// Device path: use stream-ordered D2D copies
for (size_t i = 0; i < num_lookups; ++i) {
int64_t idx = read_index(i);
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
throw std::out_of_range(
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
}
context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
}
}
return out;
}
std::string Embedding::extra_repr() const {
......
#include "infinicore/ops/embedding.hpp"
#include "../../utils.hpp"
#include "infinicore/context/context.hpp"
#include <cstring>
#include <stdexcept>
namespace infinicore::op {
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment