OpenDAS / text-generation-inference · Commit bcb53903

feat: Add AML deployment

Authored Oct 15, 2022 by Olivier Dehaene
Parent: bf99afe9

Showing 7 changed files with 58 additions and 7 deletions:

- .dockerignore (+1, −0)
- aml/README.md (+7, −0)
- aml/deployment.yaml (+39, −0)
- aml/endpoint.yaml (+3, −0)
- router/src/server.rs (+4, −3)
- server/bloom_inference/model.py (+3, −3)
- server/bloom_inference/prepare_weights.py (+1, −1)
.dockerignore

```diff
+aml
 router/target
\ No newline at end of file
```
aml/README.md (new file)

```shell
docker build . -t db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
docker push db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1

az ml online-endpoint create -f endpoint.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
az ml online-deployment create -f deployment.yaml -g HuggingFace-BLOOM-ModelPage -w HuggingFace
```
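Once the image is pushed and both resources are created, the endpoint can be smoke-tested over HTTP. A minimal sketch in Python, assuming a scoring URI and AML token obtained from Azure ML after deployment (both placeholders below), and assuming the router accepts a JSON body with an `inputs` field; the exact request shape is defined in `router/src/server.rs`:

```python
import requests

# Placeholders: both values come from Azure ML once the endpoint is live
# (auth_mode is aml_token, so every request carries a bearer token).
SCORING_URI = "https://<endpoint>.<region>.inference.ml.azure.com/generate"  # hypothetical
AML_TOKEN = "<token issued by Azure ML>"  # hypothetical

# /generate is the scoring_route declared in aml/deployment.yaml.
response = requests.post(
    SCORING_URI,
    headers={"Authorization": f"Bearer {AML_TOKEN}"},
    json={"inputs": "Hello, I am"},  # request shape is an assumption
    timeout=90,  # mirrors request_timeout_ms: 90000 in deployment.yaml
)
response.raise_for_status()
print(response.json())
```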
aml/deployment.yaml (new file)

```yaml
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
name: bloom-deployment
endpoint_name: bloom-inference
model:
  name: bloom
  path: ./bloom
model_mount_path: /var/azureml-model
environment_variables:
  MODEL_BASE_PATH: /var/azureml-model/bloom
  MODEL_NAME: bigscience/bloom
  NUM_GPUS: 8
environment:
  image: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation:0.1
  inference_config:
    liveness_route:
      port: 3000
      path: /health
    readiness_route:
      port: 3000
      path: /health
    scoring_route:
      port: 3000
      path: /generate
instance_type: Standard_ND96amsr_A100_v4
request_settings:
  request_timeout_ms: 90000
liveness_probe:
  initial_delay: 300
  timeout: 20
  period: 60
  success_threshold: 1
  failure_threshold: 60
readiness_probe:
  initial_delay: 300
  timeout: 20
  period: 60
  success_threshold: 1
  failure_threshold: 60
instance_count: 1
```
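The probe settings are sized for a model as large as BLOOM: `initial_delay: 300` gives the shards five minutes before the first check, and with `period: 60` and `failure_threshold: 60` the platform tolerates roughly an hour of failing probes before recycling the container. A rough sketch of the schedule these numbers describe, assuming the router listens on port 3000 as configured above (an illustration, not AML's actual prober):

```python
import time
import requests

# Values copied from the liveness_probe block in aml/deployment.yaml.
INITIAL_DELAY, PERIOD, TIMEOUT, FAILURE_THRESHOLD = 300, 60, 20, 60

def probe(url: str = "http://localhost:3000/health") -> bool:
    try:
        return requests.get(url, timeout=TIMEOUT).status_code == 200
    except requests.RequestException:
        return False

time.sleep(INITIAL_DELAY)  # grace period while the shards load the weights
failures = 0
while failures < FAILURE_THRESHOLD:  # runs for as long as the probe passes
    failures = 0 if probe() else failures + 1
    time.sleep(PERIOD)
print("failure_threshold reached: the container would be restarted")
```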
aml/endpoint.yaml (new file)

```yaml
$schema: https://azuremlsdk2.blob.core.windows.net/latest/managedOnlineEndpoint.schema.json
name: bloom-inference
auth_mode: aml_token
```
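With `auth_mode: aml_token`, callers authenticate with a short-lived token issued by Azure ML rather than a static key. A hedged sketch of fetching one from Python by shelling out to the CLI; the `get-credentials` subcommand and the `accessToken` field are assumptions about the `az ml` extension, not something this commit defines:

```python
import json
import subprocess

# Assumption: `az ml online-endpoint get-credentials` from the Azure CLI
# ml extension returns a JSON document containing a bearer token.
result = subprocess.run(
    [
        "az", "ml", "online-endpoint", "get-credentials",
        "--name", "bloom-inference",
        "--resource-group", "HuggingFace-BLOOM-ModelPage",
        "--workspace-name", "HuggingFace",
    ],
    check=True, capture_output=True, text=True,
)
token = json.loads(result.stdout)["accessToken"]  # field name is an assumption
```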
router/src/server.rs

```diff
-use crate::{Batcher, ShardedClient, Validation};
+use bloom_inference_client::ShardedClient;
+use crate::{Batcher, Validation};
 use axum::extract::Extension;
 use axum::http::StatusCode;
-use axum::routing::post;
+use axum::routing::{get, post};
 use axum::{Json, Router};
 use serde::Deserialize;
 use std::net::SocketAddr;

@@ -142,7 +143,7 @@ pub async fn run(client: ShardedClient, tokenizer: Tokenizer, addr: SocketAddr)
     let app = Router::new()
         .route("/generate", post(generate))
         .layer(Extension(shared_state.clone()))
-        .route("/health", post(liveness))
+        .route("/health", get(liveness))
         .layer(Extension(shared_state.clone()));

     axum::Server::bind(&addr)
```
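The routing change is what makes the probes in `aml/deployment.yaml` work: HTTP health probes issue plain GET requests, so a `/health` route registered with `post(...)` answers 405 Method Not Allowed and the liveness and readiness checks never pass. A quick local check of the fixed behaviour, assuming the router is running on port 3000:

```python
import requests

base = "http://localhost:3000"

# After this commit /health answers GET, which is what AML's probes send;
# before it, only POST was routed and a probe's GET came back 405.
r = requests.get(f"{base}/health")
print(r.status_code)  # 200 once the shards are up
```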
server/bloom_inference/model.py

```diff
@@ -9,7 +9,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from transformers.modeling_utils import no_init_weights
 from bloom_inference.pb import generate_pb2
-from bloom_inference.shard_model import shard_model, match_suffix
+from bloom_inference.prepare_weights import prepare_weights, match_suffix
 from bloom_inference.utils import (
     StoppingCriteria,
     NextTokenChooser,

@@ -377,8 +377,8 @@ class BLOOMSharded(BLOOM):
         # shard state_dict
         if self.master:
             # TODO @thomasw21 do some caching
-            shard_state_dict_paths = shard_model(
-                model_name, shard_directory, tp_world_size=self.world_size, dtype=dtype
+            shard_state_dict_paths = prepare_weights(
+                model_name, shard_directory / "cache", shard_directory, tp_world_size=self.world_size
             )
             shard_state_dict_paths = [
                 str(path.absolute()) for path in shard_state_dict_paths
```
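The call site now matches `prepare_weights`, whose signature (visible in the next file) takes the model name, a cache path, a save path, and the tensor-parallel world size. A sketch of the same call made standalone, e.g. to pre-shard the weights before packaging the model; the shard directory here is a hypothetical local path:

```python
from pathlib import Path

from bloom_inference.prepare_weights import prepare_weights

shard_directory = Path("/var/azureml-model/bloom")  # hypothetical path

# Mirrors the call site in model.py: download into a cache, then write
# one shard per tensor-parallel rank into the save path.
paths = prepare_weights(
    "bigscience/bloom",
    shard_directory / "cache",  # cache_path
    shard_directory,            # save_path
    tp_world_size=8,            # matches NUM_GPUS in aml/deployment.yaml
)
print([str(p.absolute()) for p in paths])
```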
server/bloom_inference/prepare_weights.py

```diff
@@ -62,7 +62,7 @@ def prepare_weights(model_name: str, cache_path: Path, save_path: Path, tp_world
     if all(save_path.exists() for save_path in save_paths):
         print("Weights are already prepared")
-        return
+        return save_paths

     cache_path.mkdir(parents=True, exist_ok=True)

     if model_name == "bigscience/bloom-560m":
```
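The one-line fix matters because `model.py` now consumes the return value: with the bare `return`, a warm cache handed `None` back to the caller, and the list comprehension over the shard paths raised a `TypeError` on every run after the first. A minimal illustration of the failure mode (the helper names here are hypothetical):

```python
from pathlib import Path

def old_prepare(save_paths: list[Path], cached: bool):
    if cached:
        return  # bug: the cached branch implicitly returned None
    return save_paths

def new_prepare(save_paths: list[Path], cached: bool):
    if cached:
        return save_paths  # fix: the cached branch returns the paths too
    return save_paths

paths = [Path("/tmp/shard_0.pt")]
try:
    [str(p.absolute()) for p in old_prepare(paths, cached=True)]
except TypeError as err:
    print(err)  # 'NoneType' object is not iterable
print([str(p.absolute()) for p in new_prepare(paths, cached=True)])
```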