"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "f42835aa875b5e76b979dac4f26c45e1ac3ed5e5"
Unverified Commit 3d328d61 authored by PanZezhong1725, committed by GitHub

issue/92 Add InferEngine with multi-threaded inference support

parent 0794f307
@@ -3,9 +3,12 @@
namespace infinilm::engine::distributed {
CommunicationGroup::CommunicationGroup(const DistConfig &dist_config)
: dist_config_(dist_config),
CommunicationGroup::CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type)
: dist_config_(dist_config), device_type_(device_type),
communicators_(std::vector<infinicclComm_t>(dist_config.tp_device_ids.size(), nullptr)) {
if (infinicore::context::getDevice().getType() != device_type_) {
infinicore::context::setDevice(infinicore::Device(device_type_, 0));
}
if (dist_config_.tp_device_ids.size() > 1) {
RUN_INFINI(infinicclCommInitAll(
(infiniDevice_t)infinicore::context::getDevice().getType(),
@@ -15,15 +18,21 @@ CommunicationGroup::CommunicationGroup(const DistConfig &dist_config)
}
}
const DistConfig &CommunicationGroup::getDistConfig() const {
const DistConfig &CommunicationGroup::get_dist_config() const {
return dist_config_;
}
RankCommunicator CommunicationGroup::getRankCommunicator(int rank) const {
RankCommunicator rc;
rc.info = dist_config_.getRankInfo(rank);
rc.comm = communicators_[rank];
return rc;
RankInfo CommunicationGroup::get_rank_info(int rank) const {
RankInfo info;
info.tp_size = dist_config_.tp_device_ids.size();
info.tp_rank = rank;
info.device = infinicore::Device(device_type_, dist_config_.tp_device_ids[rank]);
info.comm = communicators_[rank];
return info;
}
int CommunicationGroup::get_world_size() const {
return dist_config_.tp_device_ids.size();
}
CommunicationGroup::~CommunicationGroup() {
......
@@ -5,29 +5,48 @@
#include <infiniccl.h>
#include <infinicore/context/context.hpp>
#include <sstream>
#include <vector>
namespace infinilm::engine::distributed {
// Communicator each rank will hold
struct RankCommunicator {
RankInfo info;
struct RankInfo {
// Device Type and ID assigned to this rank
infinicore::Device device;
// Tensor parallelism size
int tp_size;
// Tensor parallelism rank number of this rank
int tp_rank;
// Communicator handle
infinicclComm_t comm;
RankInfo(infinicore::Device _device = infinicore::context::getDevice())
: device(_device), tp_size(1), tp_rank(0), comm(nullptr) {}
std::string to_string() const {
std::stringstream ss;
ss << "RankInfo: device=" << device.toString() << ", tp_size=" << tp_size << ", tp_rank=" << tp_rank;
return ss.str();
}
};
// The communication group managed by model infer engine
class CommunicationGroup {
public:
explicit CommunicationGroup(const DistConfig &dist_config);
explicit CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type);
const DistConfig &get_dist_config() const;
const DistConfig &getDistConfig() const;
RankInfo get_rank_info(int rank) const;
RankCommunicator getRankCommunicator(int rank) const;
int get_world_size() const;
~CommunicationGroup();
protected:
DistConfig dist_config_;
infinicore::Device::Type device_type_;
std::vector<infinicclComm_t> communicators_;
};
......
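For orientation, a minimal usage sketch of the per-rank API above; the header name, the two-device DistConfig, and the logging call are illustrative assumptions, and the device type is taken from the current context as the bindings below also do:

// Sketch only: enumerate the RankInfo of every rank in a CommunicationGroup.
#include "communication_group.hpp" // assumed header name for the class above
#include <spdlog/spdlog.h>
#include <vector>

void dump_ranks() {
    using namespace infinilm::engine::distributed;
    DistConfig cfg(std::vector<int>{0, 1}); // two TP ranks on devices 0 and 1 (assumed available)
    CommunicationGroup group(cfg, infinicore::context::getDevice().getType());
    for (int r = 0; r < group.get_world_size(); ++r) {
        RankInfo info = group.get_rank_info(r);
        spdlog::info("{}", info.to_string()); // e.g. "RankInfo: device=..., tp_size=2, tp_rank=0"
    }
}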
#include "dist_config.hpp"
namespace infinilm::engine::distributed {
// ---------------- RankInfo ----------------
RankInfo::RankInfo()
: tp_size(1), tp_rank(0), device_id(0) {}
RankInfo::RankInfo(int tp_size_, int tp_rank_, int device_id_)
: tp_size(tp_size_), tp_rank(tp_rank_), device_id(device_id_) {}
RankInfo::RankInfo(int tp_size_, int tp_rank_)
: RankInfo(tp_size_, tp_rank_, tp_rank_) {}
// ---------------- DistConfig ----------------
DistConfig::DistConfig()
: tp_device_ids{0} {}
@@ -28,8 +14,16 @@ DistConfig::DistConfig(int tp_size)
DistConfig::DistConfig(const std::vector<int> &tp_device_ids_)
: tp_device_ids(tp_device_ids_) {}
RankInfo DistConfig::getRankInfo(int rank) const {
return RankInfo(tp_device_ids.size(), rank, tp_device_ids[rank]);
DistConfig::operator std::string() const {
std::string repr = "DistConfig(tp_device_ids=[";
for (size_t i = 0; i < tp_device_ids.size(); ++i) {
repr += std::to_string(tp_device_ids[i]);
if (i != tp_device_ids.size() - 1) {
repr += ", ";
}
}
repr += "])";
return repr;
}
} // namespace infinilm::engine::distributed
#pragma once
#include <string>
#include <vector>
namespace infinilm::engine::distributed {
struct RankInfo {
// Tensor parallelism size
int tp_size;
// Tensor parallelism rank number of this rank
int tp_rank;
// Device ID assigned to this rank
int device_id;
RankInfo();
RankInfo(int tp_size_, int tp_rank_, int device_id_);
RankInfo(int tp_size_, int tp_rank_);
};
struct DistConfig {
// Device IDs for each rank in tensor parallelism
std::vector<int> tp_device_ids;
@@ -25,7 +13,7 @@ struct DistConfig {
explicit DistConfig(int tp_size);
explicit DistConfig(const std::vector<int> &tp_device_ids_);
RankInfo getRankInfo(int rank) const;
explicit operator std::string() const;
};
} // namespace infinilm::engine::distributed
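A small sketch of the three constructors and the new string conversion; the expected values in the comments follow the implementation above and the pybind docstrings further below:

#include "dist_config.hpp"
#include <string>
#include <vector>

std::string dist_config_examples() {
    using namespace infinilm::engine::distributed;
    DistConfig a;                         // default: tp_device_ids == {0}
    DistConfig b(2);                      // auto-assigns device IDs 0..1
    DistConfig c(std::vector<int>{2, 3}); // explicit device IDs
    return std::string(c);                // "DistConfig(tp_device_ids=[2, 3])"
}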
#include "infer_engine.hpp"
namespace infinilm::engine {
//------------------------------------------------------
// Constructor
//------------------------------------------------------
InferEngine::InferEngine(
const std::any &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type)
: communication_group_(distributed_config, device_type),
model_config_(config) {
spdlog::info("Launch InferEngine with {}", std::string(distributed_config));
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(model_config_, communication_group_.get_rank_info(r)));
}
}
//------------------------------------------------------
// load_param
//------------------------------------------------------
void InferEngine::load_param(const std::string &name, const infinicore::Tensor &param) {
// Load the parameter on all workers
for (auto &worker : workers_) {
worker->load_param(name, param);
}
}
//------------------------------------------------------
// generate
//------------------------------------------------------
infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids) {
// Trigger each worker to run inference
for (auto &worker : workers_) {
worker->run(std::vector<std::any>({input_ids, position_ids}));
}
return workers_[0]->get_output();
}
//------------------------------------------------------
// Destructor
//------------------------------------------------------
InferEngine::~InferEngine() {
// Close all workers
for (auto &worker : workers_) {
worker->close();
}
}
const distributed::DistConfig &InferEngine::get_dist_config() const {
return communication_group_.get_dist_config();
}
} // namespace infinilm::engine
#pragma once
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"
#include <any>
#include <memory>
#include <vector>
namespace infinilm::engine {
class InferEngine {
public:
InferEngine(
const std::any &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType());
// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
// Run a single forward pass on all workers and return the output produced by rank 0
infinicore::Tensor generate(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids);
~InferEngine();
const distributed::DistConfig &get_dist_config() const;
protected:
std::vector<std::unique_ptr<RankWorker>> workers_;
distributed::CommunicationGroup communication_group_;
std::any model_config_;
};
} // namespace infinilm::engine
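A hedged end-to-end sketch of driving the engine from C++; the parameter name is a placeholder, and the config and tensors are assumed to be built by the caller:

#include "infer_engine.hpp"
#include <any>

infinicore::Tensor run_once(const std::any &model_config,
                            const infinicore::Tensor &weight,
                            const infinicore::Tensor &input_ids,
                            const infinicore::Tensor &position_ids) {
    infinilm::engine::InferEngine engine(
        model_config,
        infinilm::engine::distributed::DistConfig(2)); // 2-way tensor parallelism
    engine.load_param("model.embed_tokens.weight", weight); // placeholder name; every worker receives it
    return engine.generate(input_ids, position_ids); // output produced by rank 0
}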
#include "rank_worker.hpp"
#include "../models/model_factory.hpp"
#include <sstream>
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace infinilm::engine {
RankWorker::RankWorker(const std::any &model_config,
const distributed::RankInfo &rank_info)
: model_config_(model_config),
rank_info_(rank_info),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false) {
// start the thread
thread_ = std::thread(&RankWorker::thread_loop, this);
// Wait until the worker thread finishes initialization (model created)
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return init_done_; });
}
std::string RankWorker::info() const {
std::stringstream ss;
ss << "RankWorker{";
// Rank related
ss << rank_info_.to_string() << " ";
// Flags
ss << "| init_done: " << (init_done_ ? "true" : "false") << " ";
ss << "| should_exit: " << (should_exit_ ? "true" : "false") << " ";
ss << "| has_job: " << (has_job_ ? "true" : "false") << " ";
ss << "| job_done: " << (job_done_ ? "true" : "false") << " ";
ss << "}";
return ss.str();
}
//------------------------------------------------------
// load_param -- synchronous (blocks until worker finishes loading)
//------------------------------------------------------
void RankWorker::load_param(const std::string &name,
const infinicore::Tensor &param) {
{
std::lock_guard<std::mutex> lock(mutex_);
// If the worker is stopping, don't submit new jobs.
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot load_param");
}
pending_param_name_ = name;
pending_param_ = param;
job_cmd_ = Command::LOAD;
has_job_ = true;
job_done_ = false;
}
cv_.notify_all();
// Wait for job completion
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; });
if (should_exit_) {
throw std::runtime_error("RankWorker stopped while loading parameter");
}
}
//------------------------------------------------------
// run -- synchronous (blocks until worker finishes forward)
//------------------------------------------------------
void RankWorker::run(const std::vector<std::any> &args) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot run");
}
pending_args_ = args;
job_cmd_ = Command::RUN;
has_job_ = true;
job_done_ = false;
}
cv_.notify_all();
// Wait for job completion
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; });
if (should_exit_) {
throw std::runtime_error("RankWorker stopped while running");
}
}
//------------------------------------------------------
// close -- request shutdown and join thread
//------------------------------------------------------
void RankWorker::close() {
{
std::lock_guard<std::mutex> lock(mutex_);
should_exit_ = true;
has_job_ = false; // don't keep old jobs pending
job_cmd_ = Command::STOP;
}
cv_.notify_all();
if (thread_.joinable()) {
thread_.join();
}
}
//------------------------------------------------------
// get_output (thread safe)
//------------------------------------------------------
infinicore::Tensor RankWorker::get_output() {
std::lock_guard<std::mutex> lock(mutex_);
return output_;
}
//------------------------------------------------------
// thread_loop
//------------------------------------------------------
void RankWorker::thread_loop() {
try {
// Initialize device & model outside of holding the main mutex to avoid blocking callers.
infinicore::context::setDevice(rank_info_.device);
// Create model using factory (may be expensive)
model_ = InfinilmModelFactory::createModel(model_config_, rank_info_);
// Signal that initialization is done
{
std::lock_guard<std::mutex> lk(mutex_);
init_done_ = true;
}
cv_.notify_all();
// Main loop: wait for jobs or exit
while (true) {
Command local_cmd = Command::INIT;
std::string local_param_name;
infinicore::Tensor local_param;
std::vector<std::any> local_args;
// Wait for a job or exit
{
std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return has_job_ || should_exit_; });
if (should_exit_) {
break;
}
// capture job data and clear has_job_
local_cmd = job_cmd_;
if (local_cmd == Command::LOAD) {
local_param_name = pending_param_name_;
local_param = pending_param_;
} else if (local_cmd == Command::RUN) {
local_args = pending_args_;
}
// mark job as being processed
has_job_ = false;
job_done_ = false;
} // unlock mutex while executing the job
// Execute job outside the lock
if (local_cmd == Command::LOAD) {
try {
model_->load_parameter(local_param_name, local_param);
} catch (const std::exception &e) {
// convert exceptions to a safe behavior: set should_exit_ and notify caller
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
// log the error and break out of the loop so the thread can be joined and the waiting caller is released
spdlog::error("[{}] exception during load_parameter: {}", info(), e.what());
break;
}
// signal completion
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} else if (local_cmd == Command::RUN) {
try {
auto out = model_->forward(local_args);
{
std::lock_guard<std::mutex> lk(mutex_);
output_ = std::move(out);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during forward: {}\n", info(), e.what());
break;
}
} else {
// Shouldn't reach here (no-op)
}
} // while
} catch (const std::exception &e) {
// Top-level exception: ensure any waiters are woken and the thread exits cleanly.
{
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
}
cv_.notify_all();
spdlog::error("[{}] fatal exception in thread_loop: {} \n", info(), e.what());
}
}
} // namespace infinilm::engine
#pragma once
#include "../models/model_factory.hpp"
#include "distributed/distributed.hpp"
#include <any>
#include <condition_variable>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
namespace infinilm::engine {
class RankWorker {
enum class Command {
INIT,
LOAD,
RUN,
STOP
};
public:
RankWorker(const std::any &model_config,
const distributed::RankInfo &rank_info);
// Submit a parameter load job and wait until the load completes on the worker thread.
void load_param(const std::string &name,
const infinicore::Tensor &param);
// Submit a run (forward) job and wait until completion.
// The result can be retrieved with get_output().
void run(const std::vector<std::any> &args);
// Request worker shutdown and join the thread.
void close();
// Thread-safe accessor for last output produced by RUN.
infinicore::Tensor get_output();
std::string info() const;
private:
void thread_loop();
private:
// Worker properties
std::any model_config_;
distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_;
// Command for the pending job (protected by mutex_)
Command job_cmd_;
// State flags (protected by mutex_)
bool has_job_ = false; // a job is pending
bool job_done_ = false; // last job completed
bool should_exit_ = false; // request to stop
bool init_done_ = false; // initialization finished
// Task payloads (protected by mutex)
std::string pending_param_name_;
infinicore::Tensor pending_param_;
std::vector<std::any> pending_args_;
// Output (protected by mutex)
infinicore::Tensor output_;
// Thread sync
std::thread thread_;
std::mutex mutex_;
std::condition_variable cv_;
};
} // namespace infinilm::engine
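The whole class revolves around one handshake: the caller publishes a job under mutex_, notifies, and blocks until job_done_ or should_exit_. A standalone sketch of that pattern with simplified names (not the class itself):

#include <condition_variable>
#include <mutex>

struct Mailbox {
    std::mutex m;
    std::condition_variable cv;
    bool pending = false, done = false, exiting = false;
};

// Caller side: publish a job, wake the worker, block until it finishes.
void submit_and_wait(Mailbox &box) {
    {
        std::lock_guard<std::mutex> lock(box.m);
        box.pending = true;
        box.done = false;
    }
    box.cv.notify_all(); // wake the worker thread waiting on pending/exiting
    std::unique_lock<std::mutex> lk(box.m);
    box.cv.wait(lk, [&] { return box.done || box.exiting; });
}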
#pragma once
#include "infinicore/nn/module.hpp"
#include <any>
#include <vector>
namespace infinilm {
class InfinilmModel : public infinicore::nn::Module {
public:
virtual ~InfinilmModel() = default;
virtual infinicore::Tensor forward(std::vector<std::any>) const = 0;
};
} // namespace infinilm
@@ -6,6 +6,8 @@ namespace infinilm::models::llama
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype) {
device_ = device;
// Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, dtype);
@@ -17,10 +19,11 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::
}
infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
std::vector<void *> *kv_caches) const {
const infinicore::Tensor &position_ids,
std::vector<void *> *kv_caches) const {
// 1. Forward through base model to get hidden states
auto hidden_states = model_->forward(input_ids, position_ids, kv_caches);
auto position_ids_device = position_ids->to(device_);
auto hidden_states = model_->forward(input_ids, position_ids_device, kv_caches);
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
@@ -28,4 +31,22 @@ infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids
return logits;
}
infinicore::Tensor LlamaForCausalLM::forward(std::vector<std::any> args) const {
if (args.size() < 2) {
throw std::invalid_argument("LlamaForCausalLM::forward requires at least 2 arguments: input_ids and position_ids");
}
// Extract input tensors from args
const auto &input_ids = std::any_cast<const infinicore::Tensor &>(args[0]);
const auto &position_ids = std::any_cast<const infinicore::Tensor &>(args[1]);
// Optional KV caches
std::vector<void *> *kv_caches = nullptr;
if (args.size() >= 3) {
kv_caches = std::any_cast<std::vector<void *> *>(args[2]);
}
return forward(input_ids, position_ids, kv_caches);
}
} // namespace infinilm::models::llama
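A short sketch of invoking the new type-erased entry point; the include path and the tensors are assumptions, and the call dispatches to the typed overload above:

#include "infinilm_model.hpp" // assumed include path for InfinilmModel
#include <any>
#include <memory>
#include <vector>

infinicore::Tensor call_erased(const std::shared_ptr<infinilm::InfinilmModel> &model,
                               const infinicore::Tensor &input_ids,
                               const infinicore::Tensor &position_ids) {
    // No third element, so the LlamaForCausalLM overload leaves kv_caches as nullptr.
    std::vector<std::any> args{input_ids, position_ids};
    return model->forward(args);
}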
#pragma once
#include "../infinilm_model.hpp"
#include "llama_model.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/device.hpp"
namespace infinilm::models::llama {
@@ -16,17 +18,17 @@ namespace infinilm::models::llama {
*
* This matches the structure of HuggingFace's LlamaForCausalLM.
*/
class LlamaForCausalLM : public infinicore::nn::Module {
class LlamaForCausalLM : public InfinilmModel {
public:
/**
* @brief Construct LlamaForCausalLM module
*
* @param config Model configuration
* @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to F32)
* @param dtype Optional data type for model parameters (defaults to BF16)
*/
LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32);
infinicore::DataType dtype = infinicore::DataType::BF16);
/**
* @brief Forward pass: compute language modeling logits
@@ -40,8 +42,10 @@ public:
* will be added when integrating with the inference engine.
*/
infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
std::vector<void *> *kv_caches = nullptr) const;
const infinicore::Tensor &position_ids,
std::vector<void *> *kv_caches = nullptr) const;
infinicore::Tensor forward(std::vector<std::any> args) const override;
// Module information
const LlamaConfig &config() const { return model_->config(); }
@@ -54,7 +58,6 @@ protected:
// Language modeling head
INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
};
} // namespace infinilm::models::llama
#include "model_factory.hpp"
#include "llama/llama.hpp"
namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(const std::any &config, engine::distributed::RankInfo rank_info) {
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
return std::make_shared<models::llama::LlamaForCausalLM>(llama_config, rank_info.device);
} else {
throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
}
}
} // namespace infinilm
#pragma once
#include "infinilm_model.hpp"
#include "../engine/distributed/distributed.hpp"
namespace infinilm {
class InfinilmModelFactory {
public:
static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
};
} // namespace infinilm
#include "models/llama.hpp"
#include <pybind11/pybind11.h>
#include "models/llama.hpp"
#include "engine.hpp"
namespace py = pybind11;
PYBIND11_MODULE(_infinilm, m) {
m.doc() = "InfiniLM Llama model Python bindings";
infinilm::models::llama::bind_llama(m);
infinilm::engine::distributed::bind_dist_config(m);
infinilm::engine::bind_infer_engine(m);
}
#include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace infinilm::engine::distributed {
inline void bind_dist_config(py::module &m) {
py::class_<DistConfig>(m, "DistConfig")
.def(py::init<>(), "Default constructor, empty device list")
.def(py::init<int>(), py::arg("tp_size"),
"Constructor with tensor parallel size, auto-assigns device IDs 0..tp_size-1")
.def(py::init<const std::vector<int> &>(), py::arg("tp_device_ids"),
"Constructor with explicit device IDs")
.def_readwrite("tp_device_ids", &DistConfig::tp_device_ids,
"List of device IDs used in tensor parallelism")
.def("__repr__", [](const DistConfig &cfg) {
return std::string(cfg);
})
.def("__str__", [](const DistConfig &cfg) {
return std::string(cfg);
});
}
} // namespace infinilm::engine::distributed
namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
.def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev) {
return new InferEngine(std::any(cfg), dist, dev);
}),
py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType())
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def(
"generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>());
},
"Run inference on all ranks with arbitrary arguments");
// Optionally, you can add __repr__ for debugging
m.attr("InferEngine").attr("__repr__") = py::cpp_function([](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
});
}
} // namespace infinilm::engine
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
from tokenizers import decoders as _dec
from infinilm.modeling_utils import get_model_state_dict
import infinilm
from infinilm.distributed import DistConfig
import argparse
import sys
import time
@@ -75,6 +76,13 @@ def get_args():
default="How are you",
help="input prompt",
)
parser.add_argument(
"--tp",
type=int,
default=None,
help="total rank for tensor parallel",
)
return parser.parse_args()
@@ -91,7 +99,11 @@ def test(
# Create the model
# ---------------------------------------------------------------------------- #
model = infinilm.AutoLlamaModel.from_pretrained(
model_path, device=infini_device, dtype=infini_dtype, backend=backend
model_path,
device=infini_device,
dtype=infini_dtype,
backend=backend,
distributed_config=DistConfig(args.tp),
)
# ---------------------------------------------------------------------------- #
......
from .models import AutoLlamaModel
from . import distributed
__all__ = ["AutoLlamaModel"]
__all__ = ["AutoLlamaModel", "distributed"]
from .dist_config import DistConfig
__all__ = ["DistConfig"]
class DistConfig:
"""
Distributed Model Configuration.
"""
def __init__(self, tp_size=None, tp_device_ids=None):
from infinilm.lib import _infinilm
if tp_size is not None and tp_device_ids is not None:
raise ValueError("Provide either tp_size OR tp_device_ids, not both")
if tp_size is not None:
self._underlying = _infinilm.DistConfig(tp_size)
elif tp_device_ids is not None:
self._underlying = _infinilm.DistConfig(tp_device_ids)
else:
self._underlying = _infinilm.DistConfig()
@property
def tp_device_ids(self):
return self._underlying.tp_device_ids
@tp_device_ids.setter
def tp_device_ids(self, value):
self._underlying.tp_device_ids = list(value)
def __repr__(self):
return repr(self._underlying)
def __str__(self):
return str(self._underlying)
@@ -47,18 +47,14 @@ class GenerationMixin:
self,
bs: int,
seq_length: int,
device: infinicore.device,
) -> infinicore.Tensor:
"""Calculates `position_ids` for the pre-fill stage"""
position_ids_list = [list(range(0, seq_length)) for i in range(bs)]
return infinicore.from_list(
position_ids_list, dtype=infinicore.int64, device=device
)
return infinicore.from_list(position_ids_list, dtype=infinicore.int64)
def prepare_inputs_for_generation(
self,
device: infinicore.device,
past_key_values: Optional[Cache] = None,
**kwargs,
):
@@ -79,9 +75,7 @@
if current_position_ids is None:
# prefill stage
bs, seq_len = kwargs["input_ids"].shape[0:2]
model_inputs["position_ids"] = self._get_initial_position_ids(
bs, seq_len, device
)
model_inputs["position_ids"] = self._get_initial_position_ids(bs, seq_len)
else:
# decode stage
@@ -119,8 +113,8 @@
self,
input_ids: infinicore.Tensor,
max_new_tokens: int,
device: infinicore.device,
tokenizer,
stop_on_eos=True,
**kwargs,
):
model_kwargs = kwargs
@@ -141,8 +135,8 @@
result = self._sample(
input_ids,
max_new_tokens=max_new_tokens,
device=device,
tokenizer=tokenizer,
stop_on_eos=stop_on_eos,
**model_kwargs,
)
return result
@@ -151,8 +145,8 @@
self,
input_ids: infinicore.Tensor,
max_new_tokens: int,
device: infinicore.device,
tokenizer,
stop_on_eos=True,
**model_kwargs,
):
r"""
@@ -187,7 +181,7 @@
# -------------------------------------------------------------------------- #
# prepare model inputs
# -------------------------------------------------------------------------- #
model_inputs = self.prepare_inputs_for_generation(device, **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(**model_kwargs)
model_kwargs["position_ids"] = model_inputs["position_ids"]
@@ -240,7 +234,7 @@
output_content += output_str
print(output_str, end="", flush=True)
if token_id in eos_token_id_list:
if stop_on_eos and token_id in eos_token_id_list:
break
print("\n</s>")
......