Commit efd602c8, authored by xuxzh1

Commit message: last

Parent: f1b779fc
 install-server:
	cd server && make install

-install-custom-kernels:
-	if [ "$$BUILD_EXTENSIONS" = "True" ]; then cd server/custom_kernels && python setup.py install; else echo "Custom kernels are disabled, you need to set the BUILD_EXTENSIONS environment variable to 'True' in order to build them. (Please read the docs, kernels might not work on all hardware)"; fi
-
-install-integration-tests:
-	cd integration-tests && pip install -r requirements.txt
-	cd clients/python && pip install .
+install-server-cpu:
+	cd server && make install-server

 install-router:
-	cd router && cargo install --path .
+	cd router && cargo install --path . --debug

 install-launcher:
 	cd launcher && cargo install --path .

@@ -17,7 +13,10 @@ install-launcher:
 install-benchmark:
 	cd benchmark && cargo install --path .

-install: install-server install-router install-launcher install-custom-kernels
+install: install-server install-router install-launcher
+
+install-cpu: install-server-cpu install-router install-launcher

 server-dev:
 	cd server && make run-dev

@@ -28,6 +27,10 @@ router-dev:
 rust-tests: install-router install-launcher
 	cargo test

+install-integration-tests:
+	cd integration-tests && pip install -r requirements.txt
+	cd clients/python && pip install .
+
 integration-tests: install-integration-tests
 	pytest -s -vv -m "not private" integration-tests
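Taken together, this Makefile now exposes a full install path and a CPU-only path. A minimal sketch of how the targets defined above are typically invoked from the repository root (assuming Rust, protoc, and the Python toolchain are already installed):

```bash
# Full install, building the optional custom kernels
BUILD_EXTENSIONS=True make install

# CPU-only install: server package plus router and launcher, no custom kernels
make install-cpu

# Run the integration test suite (installs its own requirements first)
make integration-tests
```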
...
@@ -34,19 +34,19 @@ Text Generation Inference (TGI) is a framework written in Rust and Python
 Install the pytorch, triton, and flash-attn packages yourself on top of an existing Python environment:

 **Install pytorch**

 Install pytorch 2.1.0. The pytorch whl packages can be downloaded from [https://cancon.hpccube.com:65024/4/main/pytorch](https://cancon.hpccube.com:65024/4/main/pytorch); pick the pytorch 2.1.0 whl that matches your Python and DTK versions, then install it with:

-```shell
+```bash
 pip install torch*   # the downloaded torch whl package
 pip install setuptools wheel
 ```

 **Install triton**

 triton whl download: [https://cancon.hpccube.com:65024/4/main/triton](https://cancon.hpccube.com:65024/4/main/triton); pick the triton 2.1 whl that matches your Python and DTK versions.

-```shell
+```bash
 pip install triton*   # the downloaded triton whl package
 ```

 **Install flash-attn**

 flash_attn whl download: [https://cancon.hpccube.com:65024/4/main/flash_attn](https://cancon.hpccube.com:65024/4/main/flash_attn); pick the flash_attn 2.0.4 whl that matches your Python and DTK versions.

-```shell
+```bash
 pip install flash_attn*   # the downloaded flash_attn whl package
 ```
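After installing the three wheels, a quick sanity check (a sketch, not part of the original guide) confirms that they import and report the expected versions:

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import triton; print(triton.__version__)"
python -c "import flash_attn; print(flash_attn.__version__)"
```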
@@ -66,36 +66,41 @@ sudo unzip -o $PROTOC_ZIP -d /usr/local 'include/*'
 rm -f $PROTOC_ZIP
 ```

 3. Install the TGI service

-```
+```bash
 git clone http://developer.hpccube.com/codes/OpenDAS/text-generation-inference.git   # switch to the branch you need
 cd text-generation-inference

-#install vllm and exllama
+#install exllama
 cd server
-pip uninstall vllm        #optional: if the environment was prepared with method 1, first uninstall the vllm bundled in it
-make install-vllm         #install the customized vllm build
 make install-exllama      #install the exllama kernels
 make install-exllamav2    #install the exllamav2 kernels
 cd ..                     #back to the project root
+source $HOME/.cargo/env
 BUILD_EXTENSIONS=True make install   #install the text-generation service
 ```

 4. Install the benchmark tool

-```
+```bash
 cd text-generation-inference
 make install-benchmark
 ```
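With the benchmark tool installed, it is run against an already running TGI instance (see the launch sketch further below); the tokenizer name here is only a placeholder:

```bash
# Requires a running text-generation-launcher instance serving the same model
text-generation-benchmark --tokenizer-name /path/to/model
```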
 Note: if installation is slow, you can switch the default pip index to speed it up:

-```
+```bash
 pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 ```

 Likewise, if `cargo install` is too slow, it can be sped up by adding a mirror source in `~/.cargo/config`.
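A sketch of such a `~/.cargo/config` entry, using the USTC mirror as one possible example (any crates.io mirror reachable from your network works; verify the URL before relying on it):

```bash
# Append a crates.io source replacement; the mirror URL is an example, substitute your preferred mirror
mkdir -p ~/.cargo
cat >> ~/.cargo/config <<'EOF'
[source.crates-io]
replace-with = 'mirror'

[source.mirror]
registry = "sparse+https://mirrors.ustc.edu.cn/crates.io-index/"
EOF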
 ## Check the installed version

-```
+```bash
 text-generation-launcher -V   #the version number tracks the upstream release
 ```

+## Before use
+
+```bash
+export PYTORCH_TUNABLEOP_ENABLED=0
+```
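With that variable exported, the service is started through the launcher; a minimal sketch, where the model path and port are placeholders:

```bash
export PYTORCH_TUNABLEOP_ENABLED=0
# Serve a local model on port 8080; add --num-shard <N> for multi-card tensor parallelism
text-generation-launcher --model-id /path/to/model --port 8080
```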
 ## Known Issue

 -

 ## References
...
@@ -497,7 +497,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
         "Lowest: {:.2} {unit}",
         data.iter()
             .min_by(|a, b| a.total_cmp(b))
-            .unwrap_or(&std::f64::NAN)
+            .unwrap_or(&f64::NAN)
     ),
     Style::default().fg(Color::Reset),
 )]),
@@ -506,7 +506,7 @@ fn statis_spans<'a>(data: &[f64], unit: &'static str) -> Vec<Line<'a>> {
         "Highest: {:.2} {unit}",
         data.iter()
             .max_by(|a, b| a.total_cmp(b))
-            .unwrap_or(&std::f64::NAN)
+            .unwrap_or(&f64::NAN)
     ),
     Style::default().fg(Color::Reset),
 )]),
@@ -555,17 +555,17 @@ fn latency_throughput_chart<'a>(
     let min_latency: f64 = *latency_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_latency: f64 = *latency_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let min_throughput: f64 = *throughput_iter
         .clone()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max_throughput: f64 = *throughput_iter
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);

     // Char min max values
     let min_x = if zoom {
...
@@ -11,7 +11,7 @@ pub(crate) enum Event {
     /// Key press.
     Key(event::KeyEvent),
     /// Terminal resize.
-    Resize(u16, u16),
+    Resize,
 }

 pub(crate) async fn terminal_event_task(
@@ -47,8 +47,8 @@ async fn event_loop(fps: u32, event_sender: mpsc::Sender<Event>) {
     if event::poll(Duration::from_secs(0)).expect("no events available") {
         match event::read().expect("unable to read event") {
             event::Event::Key(e) => event_sender.send(Event::Key(e)).await.unwrap_or(()),
-            event::Event::Resize(w, h) => {
-                event_sender.send(Event::Resize(w, h)).await.unwrap_or(())
+            event::Event::Resize(_w, _h) => {
+                event_sender.send(Event::Resize).await.unwrap_or(())
             }
             _ => (),
         }
...
 use std::time::{Duration, Instant};
-use text_generation_client::{
-    Batch, CachedBatch, ClientError, NextTokenChooserParameters, Request, ShardedClient,
+use text_generation_client::v3::{
+    Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
     StoppingCriteriaParameters,
 };
+use text_generation_client::{Chunk, ClientError, Input};
 use tokenizers::{Tokenizer, TruncationDirection};
 use tokio::sync::{broadcast, mpsc};
@@ -142,6 +143,9 @@ async fn prefill(
         .map(|id| Request {
             id: id.into(),
             prefill_logprobs: false,
+            input_chunks: Some(Input {
+                chunks: vec![Chunk::Text(sequence.clone()).into()],
+            }),
             inputs: sequence.clone(),
             truncate: sequence_length,
             parameters: Some(parameters.clone()),
@@ -151,6 +155,9 @@ async fn prefill(
                 ignore_eos_token: true, // Will not stop even if a eos token is generated
             }),
             top_n_tokens: top_n_tokens.unwrap_or(0),
+            blocks: vec![],
+            slots: vec![],
+            adapter_id: None,
         })
         .collect();
@@ -159,6 +166,7 @@ async fn prefill(
         requests,
         size: batch_size,
         max_tokens: batch_size * (sequence_length + decode_length),
+        max_blocks: 0,
     };

     // Run prefill
...
@@ -8,7 +8,7 @@ use crate::app::App;
 use crate::event::Event;
 use crossterm::ExecutableCommand;
 use std::io;
-use text_generation_client::{GrammarType, NextTokenChooserParameters, ShardedClient};
+use text_generation_client::v3::{GrammarType, NextTokenChooserParameters, ShardedClient};
 use tokenizers::Tokenizer;
 use tokio::sync::{broadcast, mpsc};
 use tui::backend::CrosstermBackend;
...
@@ -4,7 +4,7 @@
 /// and: https://github.com/orhun/rust-tui-template
 use clap::Parser;
 use std::path::Path;
-use text_generation_client::ShardedClient;
+use text_generation_client::v3::ShardedClient;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
@@ -147,7 +147,9 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     tracing::info!("Downloading tokenizer");

     // Parse Huggingface hub token
-    let auth_token = std::env::var("HUGGING_FACE_HUB_TOKEN").ok();
+    let auth_token = std::env::var("HF_TOKEN")
+        .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
+        .ok();

     // Download and instantiate tokenizer
     // We need to download it outside of the Tokio runtime
...
@@ -156,17 +156,17 @@ fn avg_min_max(data: &[f64]) -> (f64, f64, f64) {
     let min = data
         .iter()
         .min_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     let max = data
         .iter()
         .max_by(|a, b| a.total_cmp(b))
-        .unwrap_or(&std::f64::NAN);
+        .unwrap_or(&f64::NAN);
     (average, *min, *max)
 }

 fn px(data: &[f64], p: u32) -> f64 {
     let i = (f64::from(p) / 100.0 * data.len() as f64) as usize;
-    *data.get(i).unwrap_or(&std::f64::NAN)
+    *data.get(i).unwrap_or(&f64::NAN)
 }

 fn format_value(value: f64, unit: &'static str) -> String {
...
@@ -37,7 +37,7 @@ pub(crate) fn percentiles(values: &[f64], pecents: &[i32]) -> BTreeMap<String, f
     .iter()
     .map(|&p| {
         let i = (f64::from(p) / 100.0 * values.len() as f64) as usize;
-        (format!("p{p}"), *values.get(i).unwrap_or(&std::f64::NAN))
+        (format!("p{p}"), *values.get(i).unwrap_or(&f64::NAN))
     })
     .collect()
 }
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.6.0"
+__version__ = "0.7.0"
+
+DEPRECATION_WARNING = (
+    "`text_generation` clients are deprecated and will be removed in the near future. "
+    "Please use the `InferenceClient` from the `huggingface_hub` package instead."
+)

 from text_generation.client import Client, AsyncClient
 from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
 from enum import Enum
-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, ConfigDict
 from typing import Optional, List, Union, Any

 from text_generation.errors import ValidationError
@@ -46,30 +46,6 @@ class Tool(BaseModel):
     function: dict

-class ChatCompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    message: Message
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-    # Usage details of the chat completion
-    usage: Optional[Any] = None
-
-class CompletionComplete(BaseModel):
-    # Index of the chat completion
-    index: int
-    # Message associated with the chat completion
-    text: str
-    # Log probabilities for the chat completion
-    logprobs: Optional[Any]
-    # Reason for completion
-    finish_reason: str
-
 class Function(BaseModel):
     name: Optional[str]
     arguments: str
@@ -95,24 +71,41 @@ class Choice(BaseModel):
     finish_reason: Optional[str] = None

-class ChatCompletionChunk(BaseModel):
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[Choice]
+class CompletionRequest(BaseModel):
+    # Model identifier
+    model: str
+    # Prompt
+    prompt: str
+    # The parameter for repetition penalty. 1.0 means no penalty.
+    # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+    repetition_penalty: Optional[float] = None
+    # The parameter for frequency penalty. 1.0 means no penalty
+    # Penalize new tokens based on their existing frequency in the text so far,
+    # decreasing the model's likelihood to repeat the same line verbatim.
+    frequency_penalty: Optional[float] = None
+    # Maximum number of tokens to generate
+    max_tokens: Optional[int] = None
+    # Flag to indicate streaming response
+    stream: bool = False
+    # Random sampling seed
+    seed: Optional[int] = None
+    # Sampling temperature
+    temperature: Optional[float] = None
+    # Top-p value for nucleus sampling
+    top_p: Optional[float] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None

-class ChatComplete(BaseModel):
-    # Chat completion details
-    id: str
-    object: str
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[ChatCompletionComplete]
-    usage: Any
+class CompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    text: str
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str

 class Completion(BaseModel):
@@ -163,6 +156,41 @@ class ChatRequest(BaseModel):
     tool_prompt: Optional[str] = None
     # Choice of tool to be used
     tool_choice: Optional[str] = None
+    # Stop generating tokens if a member of `stop` is generated
+    stop: Optional[List[str]] = None
+
+class ChatCompletionComplete(BaseModel):
+    # Index of the chat completion
+    index: int
+    # Message associated with the chat completion
+    message: Message
+    # Log probabilities for the chat completion
+    logprobs: Optional[Any]
+    # Reason for completion
+    finish_reason: str
+    # Usage details of the chat completion
+    usage: Optional[Any] = None
+
+class ChatComplete(BaseModel):
+    # Chat completion details
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[ChatCompletionComplete]
+    usage: Any
+
+class ChatCompletionChunk(BaseModel):
+    id: str
+    object: str
+    created: int
+    model: str
+    system_fingerprint: str
+    choices: List[Choice]

 class Parameters(BaseModel):
@@ -424,5 +452,9 @@ class StreamResponse(BaseModel):
 # Inference API currently deployed model
 class DeployedModel(BaseModel):
+    # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
+    # with model_ prefixes, since this disables guardrails for colliding fields:
+    # https://github.com/pydantic/pydantic/issues/9177
+    model_config = ConfigDict(protected_namespaces=())
     model_id: str
     sha: str
...
Documentation available at: https://huggingface.co/docs/text-generation-inference
## Release
When making a release, please update the latest version in the documentation with:
```
export OLD_VERSION="2\.0\.3"
export NEW_VERSION="2\.0\.4"
find . -name '*.md' -exec sed -i -e "s/$OLD_VERSION/$NEW_VERSION/g" {} \;
```
@@ -3,12 +3,22 @@
     title: Text Generation Inference
   - local: quicktour
     title: Quick Tour
+  - local: installation_nvidia
+    title: Using TGI with Nvidia GPUs
+  - local: installation_amd
+    title: Using TGI with AMD GPUs
+  - local: installation_gaudi
+    title: Using TGI with Intel Gaudi
+  - local: installation_inferentia
+    title: Using TGI with AWS Inferentia
   - local: installation
-    title: Installation
+    title: Installation from source
   - local: supported_models
     title: Supported Models and Hardware
   - local: messages_api
     title: Messages API
+  - local: architecture
+    title: Internal Architecture
   title: Getting started
 - sections:
   - local: basic_tutorials/consuming_tgi
@@ -20,7 +30,7 @@
   - local: basic_tutorials/using_cli
     title: Using TGI CLI
   - local: basic_tutorials/launcher
     title: All TGI CLI options
   - local: basic_tutorials/non_core_models
     title: Non-core Model Serving
   - local: basic_tutorials/safety
@@ -29,6 +39,10 @@
     title: Using Guidance, JSON, tools
   - local: basic_tutorials/visual_language_models
     title: Visual Language Models
+  - local: basic_tutorials/monitoring
+    title: Monitoring TGI with Prometheus and Grafana
+  - local: basic_tutorials/train_medusa
+    title: Train Medusa
   title: Tutorials
 - sections:
   - local: conceptual/streaming
@@ -46,6 +60,9 @@
   - local: conceptual/speculation
     title: Speculation (Medusa, ngram)
   - local: conceptual/guidance
     title: How Guidance Works (via outlines)
+  - local: conceptual/lora
+    title: LoRA (Low-Rank Adaptation)
   title: Conceptual Guides
...
@@ -2,13 +2,13 @@
 If the model you wish to serve is behind gated access or the model repository on Hugging Face Hub is private, and you have access to the model, you can provide your Hugging Face Hub access token. You can generate and copy a read token from [Hugging Face Hub tokens page](https://huggingface.co/settings/tokens)

-If you're using the CLI, set the `HUGGING_FACE_HUB_TOKEN` environment variable. For example:
+If you're using the CLI, set the `HF_TOKEN` environment variable. For example:

 ```
-export HUGGING_FACE_HUB_TOKEN=<YOUR READ TOKEN>
+export HF_TOKEN=<YOUR READ TOKEN>
 ```

-If you would like to do it through Docker, you can provide your token by specifying `HUGGING_FACE_HUB_TOKEN` as shown below.
+If you would like to do it through Docker, you can provide your token by specifying `HF_TOKEN` as shown below.

 ```bash
 model=meta-llama/Llama-2-7b-chat-hf
@@ -17,8 +17,8 @@ token=<your READ token>
 docker run --gpus all \
     --shm-size 1g \
-    -e HUGGING_FACE_HUB_TOKEN=$token \
+    -e HF_TOKEN=$token \
     -p 8080:80 \
-    -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.4 \
+    -v $volume:/data ghcr.io/huggingface/text-generation-inference:2.0.4 \
     --model-id $model
 ```
...