OpenDAS / text-generation-inference / Commits / b40e8334

feat: starcoder2 (#1605)

Unverified commit b40e8334, authored Feb 28, 2024 by OlivierDehaene, committed by GitHub on Feb 28, 2024.
Parent: 97e22369
Showing 9 changed files with 1601 additions and 16 deletions (+1601, -16).
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json (+94, -0)
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json (+394, -0)
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json (+378, -0)
integration-tests/models/test_flash_starcoder2.py (+55, -0)
proto/generate.proto (+0, -1)
server/text_generation_server/models/__init__.py (+24, -0)
server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py (+545, -0)
server/text_generation_server/models/flash_mistral.py (+25, -15)
server/text_generation_server/models/flash_starcoder2.py (+86, -0)
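For orientation, the commit wires bigcode/starcoder2-* checkpoints into TGI's flash-attention serving path. Below is a minimal usage sketch with the text_generation Python client; the endpoint URL and client call are illustrative assumptions and not part of this commit, while the prompt and parameters mirror the integration test added below.

# Minimal usage sketch (assumes a TGI server already serving bigcode/starcoder2-3b
# is reachable at http://localhost:8080; this snippet is not part of the commit).
from text_generation import Client

client = Client("http://localhost:8080")
response = client.generate(
    "def print_hello", max_new_tokens=10, decoder_input_details=True
)
print(response.generated_text)            # e.g. '():\n print("Hello World!")\n\ndef'
print(response.details.generated_tokens)  # 10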
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json (new file, mode 100644)

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      {"id": 610, "logprob": null, "text": "def"},
      {"id": 1489, "logprob": -5.2617188, "text": " print"},
      {"id": 100, "logprob": -0.38476562, "text": "_"},
      {"id": 7670, "logprob": -7.640625, "text": "hello"}
    ],
    "seed": null,
    "tokens": [
      {"id": 2284, "logprob": -0.92626953, "special": false, "text": "():"},
      {"id": 303, "logprob": -0.40844727, "special": false, "text": "\n"},
      {"id": 1489, "logprob": -0.27905273, "special": false, "text": " print"},
      {"id": 459, "logprob": -0.6118164, "special": false, "text": "(\""},
      {"id": 8302, "logprob": -0.68652344, "special": false, "text": "Hello"},
      {"id": 10914, "logprob": -1.4619141, "special": false, "text": " World"},
      {"id": 16013, "logprob": -0.7993164, "special": false, "text": "!\")"},
      {"id": 222, "logprob": -0.63134766, "special": false, "text": "\n"},
      {"id": 222, "logprob": -0.23278809, "special": false, "text": "\n"},
      {"id": 610, "logprob": -1.2294922, "special": false, "text": "def"}
    ],
    "top_tokens": null
  },
  "generated_text": "():\n print(\"Hello World!\")\n\ndef"
}
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json (new file, mode 100644)

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 60,
    "prefill": [
      {"id": 610, "logprob": null, "text": "def"},
      {"id": 1489, "logprob": -5.2617188, "text": " print"},
      {"id": 100, "logprob": -0.38476562, "text": "_"},
      {"id": 7670, "logprob": -7.640625, "text": "hello"}
    ],
    "seed": 0,
    "tokens": [
      {"id": 2284, "logprob": -0.296875, "special": false, "text": "():"},
      {"id": 303, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 1489, "logprob": 0.0, "special": false, "text": " print"},
      {"id": 459, "logprob": 0.0, "special": false, "text": "(\""},
      {"id": 8302, "logprob": -0.28125, "special": false, "text": "Hello"},
      {"id": 10914, "logprob": -0.79248047, "special": false, "text": " World"},
      {"id": 16013, "logprob": -0.61816406, "special": false, "text": "!\")"},
      {"id": 222, "logprob": -0.0619812, "special": false, "text": "\n"},
      {"id": 222, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 610, "logprob": -0.4091797, "special": false, "text": "def"},
      {"id": 1489, "logprob": 0.0, "special": false, "text": " print"},
      {"id": 100, "logprob": 0.0, "special": false, "text": "_"},
      {"id": 7670, "logprob": 0.0, "special": false, "text": "hello"},
      {"id": 100, "logprob": 0.0, "special": false, "text": "_"},
      {"id": 444, "logprob": -0.21655273, "special": false, "text": "name"},
      {"id": 45, "logprob": 0.0, "special": false, "text": "("},
      {"id": 444, "logprob": 0.0, "special": false, "text": "name"},
      {"id": 731, "logprob": 0.0, "special": false, "text": "):"},
      {"id": 303, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 1489, "logprob": 0.0, "special": false, "text": " print"},
      {"id": 459, "logprob": 0.0, "special": false, "text": "(\""},
      {"id": 8302, "logprob": 0.0, "special": false, "text": "Hello"},
      {"id": 332, "logprob": -0.034698486, "special": false, "text": "\""},
      {"id": 494, "logprob": 0.0, "special": false, "text": " +"},
      {"id": 655, "logprob": 0.0, "special": false, "text": " name"},
      {"id": 494, "logprob": -0.20141602, "special": false, "text": " +"},
      {"id": 332, "logprob": 0.0, "special": false, "text": "\""},
      {"id": 16013, "logprob": 0.0, "special": false, "text": "!\")"},
      {"id": 222, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 222, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 610, "logprob": 0.0, "special": false, "text": "def"},
      {"id": 1489, "logprob": 0.0, "special": false, "text": " print"},
      {"id": 100, "logprob": 0.0, "special": false, "text": "_"},
      {"id": 7670, "logprob": 0.0, "special": false, "text": "hello"},
      {"id": 100, "logprob": 0.0, "special": false, "text": "_"},
      {"id": 444, "logprob": 0.0, "special": false, "text": "name"},
      {"id": 100, "logprob": 0.0, "special": false, "text": "_"},
      {"id": 400, "logprob": 0.0, "special": false, "text": "age"},
      {"id": 45, "logprob": 0.0, "special": false, "text": "("},
      {"id": 444, "logprob": 0.0, "special": false, "text": "name"},
      {"id": 49, "logprob": 0.0, "special": false, "text": ","},
      {"id": 11505, "logprob": 0.0, "special": false, "text": " age"},
      {"id": 731, "logprob": 0.0, "special": false, "text": "):"},
      {"id": 303, "logprob": 0.0, "special": false, "text": "\n"},
      {"id": 1489, "logprob": 0.0, "special": false, "text": " print"},
      {"id": 459, "logprob": 0.0, "special": false, "text": "(\""},
      {"id": 8302, "logprob": 0.0, "special": false, "text": "Hello"},
      {"id": 332, "logprob": 0.0, "special": false, "text": "\""},
      {"id": 494, "logprob": 0.0, "special": false, "text": " +"},
      {"id": 655, "logprob": 0.0, "special": false, "text": " name"},
      {"id": 494, "logprob": 0.0, "special": false, "text": " +"},
      {"id": 3021, "logprob": -0.5761719, "special": false, "text": "\","},
      {"id": 863, "logprob": 0.0, "special": false, "text": " you"},
      {"id": 904, "logprob": 0.0, "special": false, "text": " are"},
      {"id": 332, "logprob": 0.0, "special": false, "text": "\""},
      {"id": 494, "logprob": 0.0, "special": false, "text": " +"},
      {"id": 615, "logprob": 0.0, "special": false, "text": " str"},
      {"id": 45, "logprob": 0.0, "special": false, "text": "("},
      {"id": 400, "logprob": 0.0, "special": false, "text": "age"},
      {"id": 46, "logprob": 0.0, "special": false, "text": ")"}
    ],
    "top_tokens": null
  },
  "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello\" + name +\"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello\" + name +\", you are\" + str(age)"
}
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json (new file, mode 100644)

[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 610, "logprob": null, "text": "def"},
        {"id": 1489, "logprob": -5.2617188, "text": " print"},
        {"id": 100, "logprob": -0.38476562, "text": "_"},
        {"id": 7670, "logprob": -7.640625, "text": "hello"}
      ],
      "seed": null,
      "tokens": [
        {"id": 2284, "logprob": -0.92626953, "special": false, "text": "():"},
        {"id": 303, "logprob": -0.40722656, "special": false, "text": "\n"},
        {"id": 1489, "logprob": -0.27954102, "special": false, "text": " print"},
        {"id": 459, "logprob": -0.6142578, "special": false, "text": "(\""},
        {"id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello"},
        {"id": 10914, "logprob": -1.4570312, "special": false, "text": " World"},
        {"id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")"},
        {"id": 222, "logprob": -0.6303711, "special": false, "text": "\n"},
        {"id": 222, "logprob": -0.23327637, "special": false, "text": "\n"},
        {"id": 610, "logprob": -1.2304688, "special": false, "text": "def"}
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 610, "logprob": null, "text": "def"},
        {"id": 1489, "logprob": -5.2617188, "text": " print"},
        {"id": 100, "logprob": -0.38476562, "text": "_"},
        {"id": 7670, "logprob": -7.640625, "text": "hello"}
      ],
      "seed": null,
      "tokens": [
        {"id": 2284, "logprob": -0.92626953, "special": false, "text": "():"},
        {"id": 303, "logprob": -0.40722656, "special": false, "text": "\n"},
        {"id": 1489, "logprob": -0.27954102, "special": false, "text": " print"},
        {"id": 459, "logprob": -0.6142578, "special": false, "text": "(\""},
        {"id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello"},
        {"id": 10914, "logprob": -1.4570312, "special": false, "text": " World"},
        {"id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")"},
        {"id": 222, "logprob": -0.6303711, "special": false, "text": "\n"},
        {"id": 222, "logprob": -0.23327637, "special": false, "text": "\n"},
        {"id": 610, "logprob": -1.2304688, "special": false, "text": "def"}
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 610, "logprob": null, "text": "def"},
        {"id": 1489, "logprob": -5.2617188, "text": " print"},
        {"id": 100, "logprob": -0.38476562, "text": "_"},
        {"id": 7670, "logprob": -7.640625, "text": "hello"}
      ],
      "seed": null,
      "tokens": [
        {"id": 2284, "logprob": -0.92626953, "special": false, "text": "():"},
        {"id": 303, "logprob": -0.40722656, "special": false, "text": "\n"},
        {"id": 1489, "logprob": -0.27954102, "special": false, "text": " print"},
        {"id": 459, "logprob": -0.6142578, "special": false, "text": "(\""},
        {"id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello"},
        {"id": 10914, "logprob": -1.4570312, "special": false, "text": " World"},
        {"id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")"},
        {"id": 222, "logprob": -0.6303711, "special": false, "text": "\n"},
        {"id": 222, "logprob": -0.23327637, "special": false, "text": "\n"},
        {"id": 610, "logprob": -1.2304688, "special": false, "text": "def"}
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        {"id": 610, "logprob": null, "text": "def"},
        {"id": 1489, "logprob": -5.2617188, "text": " print"},
        {"id": 100, "logprob": -0.38476562, "text": "_"},
        {"id": 7670, "logprob": -7.640625, "text": "hello"}
      ],
      "seed": null,
      "tokens": [
        {"id": 2284, "logprob": -0.92626953, "special": false, "text": "():"},
        {"id": 303, "logprob": -0.40722656, "special": false, "text": "\n"},
        {"id": 1489, "logprob": -0.27954102, "special": false, "text": " print"},
        {"id": 459, "logprob": -0.6142578, "special": false, "text": "(\""},
        {"id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello"},
        {"id": 10914, "logprob": -1.4570312, "special": false, "text": " World"},
        {"id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")"},
        {"id": 222, "logprob": -0.6303711, "special": false, "text": "\n"},
        {"id": 222, "logprob": -0.23327637, "special": false, "text": "\n"},
        {"id": 610, "logprob": -1.2304688, "special": false, "text": "def"}
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  }
]
integration-tests/models/test_flash_starcoder2.py (new file, mode 100644)

import pytest


@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
    with launcher("bigcode/starcoder2-3b", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
    await flash_starcoder2_handle.health(300)
    return flash_starcoder2_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello",
        max_new_tokens=60,
        temperature=0.2,
        top_p=0.95,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 60
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(flash_starcoder2, generate_load, response_snapshot):
    responses = await generate_load(
        flash_starcoder2, "def print_hello", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
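The launcher, generate_load, and response_snapshot names above are fixtures supplied by the integration-test harness rather than this file. A rough sketch of what a generate_load-style helper does, shown here only to make the load test readable (the real fixture lives elsewhere in the test suite and may differ):

# Illustrative sketch only: fire n identical generate requests concurrently and
# return every response, so the test can assert they all agree.
import asyncio

async def generate_load_sketch(client, prompt: str, max_new_tokens: int, n: int):
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens, decoder_input_details=True)
        for _ in range(n)
    ]
    return await asyncio.gather(*futures)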
proto/generate.proto (modified)

@@ -230,7 +230,6 @@ message WarmupRequest {
   uint32 max_total_tokens = 4;
 }
 
-/// Empty response
 message WarmupResponse {
   /// Maximum number of tokens supported by the model
   optional uint32 max_supported_total_tokens = 1;
server/text_generation_server/models/__init__.py (modified)

@@ -64,6 +64,7 @@ try:
     from text_generation_server.models.flash_mistral import FlashMistral
     from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
+    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:

@@ -80,6 +81,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
     __all__.append(FlashPhi)
+    __all__.append(FlashStarcoder2)
 
 MAMBA_AVAILABLE = True
 try:

@@ -184,6 +186,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    if model_id.startswith("facebook/galactica"):
+        return GalacticaSharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if (
         model_type == "gpt_bigcode"
         or model_type == "gpt2"

@@ -401,6 +413,18 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
+    if model_type == "starcoder2":
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
+            return FlashStarcoder2(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type == "opt":
         return OPTSharded(
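The new starcoder2 branch in get_model follows the same gating pattern as the other Mistral-style models: take the flash path when there is no sliding window and FlashAttention is available, or whenever the FlashAttention v2 CUDA kernels are present. A small truth-table sketch of that condition (illustrative only; FLASH_ATTENTION and HAS_FLASH_ATTN_V2_CUDA are flags set by the server's flash-attn detection):

# Illustrative evaluation of the dispatch condition added for model_type == "starcoder2".
def selects_flash_starcoder2(sliding_window, flash_attention, has_flash_attn_v2_cuda):
    return (
        (sliding_window is None or sliding_window == -1) and flash_attention
    ) or has_flash_attn_v2_cuda

# No sliding window and flash-attn available -> flash path.
assert selects_flash_starcoder2(None, True, False)
# Sliding window set but flash-attn v2 CUDA kernels present -> still flash path.
assert selects_flash_starcoder2(4096, False, True)
# Neither condition holds -> this branch does not return FlashStarcoder2.
assert not selects_flash_starcoder2(4096, True, False)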
server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py (new file, mode 100644)

# coding=utf-8
# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    SpeculativeHead,
    get_linear,
    FastRMSNorm,
    FastLayerNorm,
)


class Starcoder2Config(PretrainedConfig):
    model_type = "starcoder2"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=3072,
        intermediate_size=12288,
        num_hidden_layers=30,
        num_attention_heads=24,
        num_key_value_heads=2,
        mlp_type="default",
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=4096,
        initializer_range=0.018042,
        norm_type="layer_norm",
        norm_epsilon=1e-5,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        rope_theta=10000.0,
        sliding_window=None,
        attention_dropout=0.0,
        residual_dropout=0.0,
        embedding_dropout=0.0,
        use_bias: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.use_bias = use_bias

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_type = mlp_type
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.norm_epsilon = norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=config.use_bias,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.use_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(get_linear(weight, bias=bias, quantize=config.quantize))


class Starcoder2Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=config.use_bias,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        if prefill_cache_indices is not None:
            kv_to_cache = kv[prefill_cache_indices]
        else:
            kv_to_cache = kv

        paged_attention.reshape_and_cache(
            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class Starcoder2MLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.c_fc = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.c_fc",
            weights=weights,
            bias=config.use_bias,
        )
        self.c_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=config.use_bias,
        )

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        return self.c_proj(hidden_states)


class Starcoder2GatedMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=config.use_bias,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=config.use_bias,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


STARCODER2_NORMALIZATION_CLASSES = {
    "layer_norm": FastLayerNorm,
    "rms_norm": FastRMSNorm,
}

STARCODER2_MLP_CLASSES = {
    "default": Starcoder2MLP,
    "gated": Starcoder2GatedMLP,
}


class Starcoder2Layer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = Starcoder2Attention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )

        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.norm_epsilon
        )
        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
            config.norm_type
        ].load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.norm_epsilon,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            prefill_cache_indices,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class Starcoder2Model(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                Starcoder2Layer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix="model.norm", weights=weights, eps=config.norm_epsilon
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, true_max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
                prefill_cache_indices,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashStarcoder2ForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = Starcoder2Model(config, weights)
        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="model.embed_tokens",
                weights=weights,
            )

        self.max_past = config.sliding_window
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
            if self.max_past is not None
            else None
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
            # Slots also need to be sliced as it has the same size as the whole kv tensor
            slots = slots[prefill_cache_indices]
        elif self.max_past is not None:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            true_max_s,
            prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits
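The grouped-query attention bookkeeping in _load_gqa and Starcoder2Attention is easiest to see with the Starcoder2Config defaults above (hidden_size=3072, 24 query heads, 2 key/value heads). A worked sketch of the fused QKV shape check, assuming a single tensor-parallel shard:

# Worked example of the shape assertion in _load_gqa, using the Starcoder2Config
# defaults and a single tensor-parallel shard (process_group.size() == 1).
hidden_size = 3072
num_attention_heads = 24
num_key_value_heads = 2
shards = 1

head_size = hidden_size // num_attention_heads           # 128
num_heads = num_attention_heads // shards                 # 24
num_kv_heads = num_key_value_heads // shards              # 2

# Fused [q_proj; k_proj; v_proj] weight: query rows plus key and value rows.
fused_rows = (num_heads + 2 * num_kv_heads) * head_size   # (24 + 4) * 128 = 3584
expected_shape = [fused_rows, hidden_size]                 # [3584, 3072]
print(expected_shape)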
server/text_generation_server/models/flash_mistral.py (modified)

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from transformers.models.llama import LlamaTokenizerFast
-from typing import Optional, Tuple, Type, List
+from typing import Optional, Tuple, Type
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM

@@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
 MEM_POOL = torch.cuda.graph_pool_handle()
 
+
+def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    SLIDING_WINDOW = sliding_window
+    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
+
+
+def get_sliding_windows() -> Tuple[int, int]:
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
+
+
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
 class FlashMistralBatch(FlashCausalLMBatch):

@@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
+        sliding_window, sliding_window_blocks = get_sliding_windows()
 
         batch_inputs = []
         max_truncation = 0

@@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if SLIDING_WINDOW_BLOCKS is not None:
-                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
+            if sliding_window_blocks is not None:
+                needed_blocks = min(needed_blocks, sliding_window_blocks)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))

@@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
                     dtype=torch.int64,
                 )

@@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(

@@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         prefill_cache_indices = (
-            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+            prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(

@@ -298,9 +310,6 @@ class BaseFlashMistral(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")

@@ -324,8 +333,9 @@ class BaseFlashMistral(FlashCausalLM):
         # Set context windows
         if config.sliding_window is not None:
-            SLIDING_WINDOW = config.sliding_window
-            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )
 
         torch.distributed.barrier(group=self.process_group)
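The refactor above replaces direct writes to the SLIDING_WINDOW / SLIDING_WINDOW_BLOCKS globals with the set_sliding_window / get_sliding_windows helpers so that FlashStarcoder2 can reuse them. A small sketch of the conversion they store (the BLOCK_SIZE value of 16 is an illustrative assumption about the KV-cache block size, not something stated in this diff):

# Sketch of the sliding-window bookkeeping set up in BaseFlashMistral.__init__.
# Assumption: BLOCK_SIZE is the paged KV-cache block size (16 used here for illustration).
import math

BLOCK_SIZE = 16
sliding_window = 4096  # e.g. a config.sliding_window value

sliding_window_blocks = math.ceil(sliding_window / BLOCK_SIZE)  # 256 blocks
# Per request, the batch code then caps needed_blocks at sliding_window_blocks,
# so the KV cache never allocates more than one window's worth of blocks:
needed_blocks = min(math.ceil(5000 / BLOCK_SIZE), sliding_window_blocks)  # min(313, 256) = 256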
server/text_generation_server/models/flash_starcoder2.py (new file, mode 100644)

import math

import torch

from typing import Optional

from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
    Starcoder2Config,
    FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        tokenizer = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = Starcoder2Config.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        # Set context windows
        if config.sliding_window is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashStarcoder2ForCausalLM(config, weights)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )
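One design detail worth noting: FlashStarcoder2 derives from BaseFlashMistral so it can reuse the Mistral batch and sliding-window machinery, but it calls super(BaseFlashMistral, self).__init__(...), which skips BaseFlashMistral.__init__ entirely and goes straight to the FlashCausalLM constructor with StarCoder2's own tokenizer, config, and modeling class.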