OpenDAS / text-generation-inference · Commits

Commit b40e8334 (unverified), authored Feb 28, 2024 by OlivierDehaene, committed via GitHub on Feb 28, 2024
Parent: 97e22369

feat: starcoder2 (#1605)

Showing 9 changed files with 1601 additions and 16 deletions (+1601, -16).
  integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json  (+94, -0)
  integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json  (+394, -0)
  integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json  (+378, -0)
  integration-tests/models/test_flash_starcoder2.py  (+55, -0)
  proto/generate.proto  (+0, -1)
  server/text_generation_server/models/__init__.py  (+24, -0)
  server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py  (+545, -0)
  server/text_generation_server/models/flash_mistral.py  (+25, -15)
  server/text_generation_server/models/flash_starcoder2.py  (+86, -0)
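In practice the feature is exercised the way the integration tests below do: launch a server for a Starcoder2 checkpoint (for example `text-generation-launcher --model-id bigcode/starcoder2-3b --num-shard 2`) and query it over HTTP. A minimal hedged sketch using the `text-generation` Python client, assuming a server is already listening locally on port 8080:

```python
from text_generation import Client  # pip install text-generation

# Assumes a text-generation-inference server for bigcode/starcoder2-3b
# is already running on this address.
client = Client("http://127.0.0.1:8080")
response = client.generate(
    "def print_hello", max_new_tokens=10, decoder_input_details=True
)
print(response.generated_text)  # the snapshot below records this exact output
```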
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2.json (new file, 0 → 100644)

```json
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 610, "logprob": null, "text": "def" },
      { "id": 1489, "logprob": -5.2617188, "text": " print" },
      { "id": 100, "logprob": -0.38476562, "text": "_" },
      { "id": 7670, "logprob": -7.640625, "text": "hello" }
    ],
    "seed": null,
    "tokens": [
      { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
      { "id": 303, "logprob": -0.40844727, "special": false, "text": "\n" },
      { "id": 1489, "logprob": -0.27905273, "special": false, "text": " print" },
      { "id": 459, "logprob": -0.6118164, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": -0.68652344, "special": false, "text": "Hello" },
      { "id": 10914, "logprob": -1.4619141, "special": false, "text": " World" },
      { "id": 16013, "logprob": -0.7993164, "special": false, "text": "!\")" },
      { "id": 222, "logprob": -0.63134766, "special": false, "text": "\n" },
      { "id": 222, "logprob": -0.23278809, "special": false, "text": "\n" },
      { "id": 610, "logprob": -1.2294922, "special": false, "text": "def" }
    ],
    "top_tokens": null
  },
  "generated_text": "():\n print(\"Hello World!\")\n\ndef"
}
```
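The snapshot's `generated_text` should simply be the concatenation of the per-step `tokens[*].text` values, and each `logprob` converts to a per-token probability. A quick hedged sketch for inspecting a local copy of a snapshot like the one above (the file name is assumed):

```python
import json
import math

with open("test_flash_starcoder2.json") as f:  # local copy of the snapshot above
    snapshot = json.load(f)

details = snapshot["details"]

# The decoded output is the concatenation of the generated token texts.
assert snapshot["generated_text"] == "".join(t["text"] for t in details["tokens"])

# Log-probabilities convert to per-token probabilities with exp().
for token in details["tokens"]:
    print(f"{token['text']!r}: p = {math.exp(token['logprob']):.3f}")
```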
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_default_params.json (new file, 0 → 100644)

```json
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 60,
    "prefill": [
      { "id": 610, "logprob": null, "text": "def" },
      { "id": 1489, "logprob": -5.2617188, "text": " print" },
      { "id": 100, "logprob": -0.38476562, "text": "_" },
      { "id": 7670, "logprob": -7.640625, "text": "hello" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 2284, "logprob": -0.296875, "special": false, "text": "():" },
      { "id": 303, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 459, "logprob": 0.0, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": -0.28125, "special": false, "text": "Hello" },
      { "id": 10914, "logprob": -0.79248047, "special": false, "text": " World" },
      { "id": 16013, "logprob": -0.61816406, "special": false, "text": "!\")" },
      { "id": 222, "logprob": -0.0619812, "special": false, "text": "\n" },
      { "id": 222, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 610, "logprob": -0.4091797, "special": false, "text": "def" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 7670, "logprob": 0.0, "special": false, "text": "hello" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 444, "logprob": -0.21655273, "special": false, "text": "name" },
      { "id": 45, "logprob": 0.0, "special": false, "text": "(" },
      { "id": 444, "logprob": 0.0, "special": false, "text": "name" },
      { "id": 731, "logprob": 0.0, "special": false, "text": "):" },
      { "id": 303, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 459, "logprob": 0.0, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": 0.0, "special": false, "text": "Hello" },
      { "id": 332, "logprob": -0.034698486, "special": false, "text": " \"" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 655, "logprob": 0.0, "special": false, "text": " name" },
      { "id": 494, "logprob": -0.20141602, "special": false, "text": " +" },
      { "id": 332, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 16013, "logprob": 0.0, "special": false, "text": "!\")" },
      { "id": 222, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 222, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 610, "logprob": 0.0, "special": false, "text": "def" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 7670, "logprob": 0.0, "special": false, "text": "hello" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 444, "logprob": 0.0, "special": false, "text": "name" },
      { "id": 100, "logprob": 0.0, "special": false, "text": "_" },
      { "id": 400, "logprob": 0.0, "special": false, "text": "age" },
      { "id": 45, "logprob": 0.0, "special": false, "text": "(" },
      { "id": 444, "logprob": 0.0, "special": false, "text": "name" },
      { "id": 49, "logprob": 0.0, "special": false, "text": "," },
      { "id": 11505, "logprob": 0.0, "special": false, "text": " age" },
      { "id": 731, "logprob": 0.0, "special": false, "text": "):" },
      { "id": 303, "logprob": 0.0, "special": false, "text": "\n" },
      { "id": 1489, "logprob": 0.0, "special": false, "text": " print" },
      { "id": 459, "logprob": 0.0, "special": false, "text": "(\"" },
      { "id": 8302, "logprob": 0.0, "special": false, "text": "Hello" },
      { "id": 332, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 655, "logprob": 0.0, "special": false, "text": " name" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 3021, "logprob": -0.5761719, "special": false, "text": " \"," },
      { "id": 863, "logprob": 0.0, "special": false, "text": " you" },
      { "id": 904, "logprob": 0.0, "special": false, "text": " are" },
      { "id": 332, "logprob": 0.0, "special": false, "text": " \"" },
      { "id": 494, "logprob": 0.0, "special": false, "text": " +" },
      { "id": 615, "logprob": 0.0, "special": false, "text": " str" },
      { "id": 45, "logprob": 0.0, "special": false, "text": "(" },
      { "id": 400, "logprob": 0.0, "special": false, "text": "age" },
      { "id": 46, "logprob": 0.0, "special": false, "text": ")" }
    ],
    "top_tokens": null
  },
  "generated_text": "():\n print(\"Hello World!\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name + \"!\")\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \", you are \" + str(age)"
}
```
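This snapshot was produced with `temperature=0.2`, `top_p=0.95`, `seed=0`, which is why almost every sampled token records `logprob: 0.0`: at low temperature the filtered distribution collapses onto a single near-certain token. A hedged, NumPy-only sketch of what temperature scaling plus nucleus (top-p) filtering do to a logit vector (an illustration, not TGI's actual sampling kernel):

```python
import numpy as np


def sample_temperature_top_p(logits, temperature=0.2, top_p=0.95, seed=0):
    """Hedged sketch of temperature + nucleus sampling."""
    rng = np.random.default_rng(seed)
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    # Keep the smallest prefix of tokens (by descending probability) whose
    # cumulative mass reaches top_p, then renormalize and sample.
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    keep = order[:cutoff]
    return int(rng.choice(keep, p=probs[keep] / probs[keep].sum()))


# At temperature 0.2, modest logit gaps become near-deterministic choices:
print(sample_temperature_top_p([2.0, 1.0, 0.5]))  # prints 0; the top token alone survives top-p
```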
integration-tests/models/__snapshots__/test_flash_starcoder2/test_flash_starcoder2_load.json (new file, 0 → 100644)

```json
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n" },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n" },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n" },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 610, "logprob": null, "text": "def" },
        { "id": 1489, "logprob": -5.2617188, "text": " print" },
        { "id": 100, "logprob": -0.38476562, "text": "_" },
        { "id": 7670, "logprob": -7.640625, "text": "hello" }
      ],
      "seed": null,
      "tokens": [
        { "id": 2284, "logprob": -0.92626953, "special": false, "text": "():" },
        { "id": 303, "logprob": -0.40722656, "special": false, "text": "\n" },
        { "id": 1489, "logprob": -0.27954102, "special": false, "text": " print" },
        { "id": 459, "logprob": -0.6142578, "special": false, "text": "(\"" },
        { "id": 8302, "logprob": -0.68310547, "special": false, "text": "Hello" },
        { "id": 10914, "logprob": -1.4570312, "special": false, "text": " World" },
        { "id": 16013, "logprob": -0.80126953, "special": false, "text": "!\")" },
        { "id": 222, "logprob": -0.6303711, "special": false, "text": "\n" },
        { "id": 222, "logprob": -0.23327637, "special": false, "text": "\n" },
        { "id": 610, "logprob": -1.2304688, "special": false, "text": "def" }
      ],
      "top_tokens": null
    },
    "generated_text": "():\n print(\"Hello World!\")\n\ndef"
  }
]
```
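All four entries are identical: with `"seed": null` decoding is greedy, so concurrent requests for the same prompt must agree token for token. A small hedged check over a local copy of the snapshot, mirroring the assertions in the load test below:

```python
import json

with open("test_flash_starcoder2_load.json") as f:  # local copy of the snapshot above
    responses = json.load(f)

# Greedy decoding is deterministic, even under concurrent load.
assert len(responses) == 4
assert all(r["generated_text"] == responses[0]["generated_text"] for r in responses)
```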
integration-tests/models/test_flash_starcoder2.py (new file, 0 → 100644)

```python
import pytest


@pytest.fixture(scope="module")
def flash_starcoder2_handle(launcher):
    with launcher("bigcode/starcoder2-3b", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_starcoder2(flash_starcoder2_handle):
    await flash_starcoder2_handle.health(300)
    return flash_starcoder2_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_default_params(flash_starcoder2, response_snapshot):
    response = await flash_starcoder2.generate(
        "def print_hello",
        max_new_tokens=60,
        temperature=0.2,
        top_p=0.95,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 60
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_starcoder2_load(
    flash_starcoder2, generate_load, response_snapshot
):
    responses = await generate_load(
        flash_starcoder2, "def print_hello", max_new_tokens=10, n=4
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
```
proto/generate.proto

```diff
@@ -230,7 +230,6 @@ message WarmupRequest {
   uint32 max_total_tokens = 4;
 }
 
-/// Empty response
 message WarmupResponse {
   /// Maximum number of tokens supported by the model
   optional uint32 max_supported_total_tokens = 1;
```
server/text_generation_server/models/__init__.py

```diff
@@ -64,6 +64,7 @@ try:
     from text_generation_server.models.flash_mistral import FlashMistral
     from text_generation_server.models.flash_mixtral import FlashMixtral
     from text_generation_server.models.flash_phi import FlashPhi
+    from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:
@@ -80,6 +81,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashMistral)
     __all__.append(FlashMixtral)
     __all__.append(FlashPhi)
+    __all__.append(FlashStarcoder2)
 
 MAMBA_AVAILABLE = True
 try:
@@ -184,6 +186,16 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
+    if model_id.startswith("facebook/galactica"):
+        return GalacticaSharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            use_medusa=use_medusa,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+
     if (
         model_type == "gpt_bigcode"
         or model_type == "gpt2"
@@ -401,6 +413,18 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
 
+    if model_type == "starcoder2":
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
+            return FlashStarcoder2(
+                model_id,
+                revision,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type == "opt":
         return OPTSharded(
```
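The guard above routes `starcoder2` checkpoints to the flash implementation when the config either declares no sliding window (any flash-attention build suffices) or flash attention v2 for CUDA is available. A hedged restatement of just that predicate, with hypothetical helper and parameter names:

```python
# Hypothetical helper restating the routing condition above; the names are
# illustrative, not part of the server API.
def should_use_flash_starcoder2(config_dict, flash_attention, has_flash_attn_v2_cuda):
    sliding_window = config_dict.get("sliding_window", -1)
    return (
        (sliding_window is None or sliding_window == -1) and flash_attention
    ) or has_flash_attn_v2_cuda


assert should_use_flash_starcoder2({"sliding_window": None}, True, False)
assert should_use_flash_starcoder2({"sliding_window": 4096}, False, True)
assert not should_use_flash_starcoder2({"sliding_window": 4096}, True, False)
```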
server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py (new file, 0 → 100644)

```python
# coding=utf-8
# Copyright 2024 Starcoder2 AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    SpeculativeHead,
    get_linear,
    FastRMSNorm,
    FastLayerNorm,
)


class Starcoder2Config(PretrainedConfig):
    model_type = "starcoder2"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=3072,
        intermediate_size=12288,
        num_hidden_layers=30,
        num_attention_heads=24,
        num_key_value_heads=2,
        mlp_type="default",
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=4096,
        initializer_range=0.018042,
        norm_type="layer_norm",
        norm_epsilon=1e-5,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        rope_theta=10000.0,
        sliding_window=None,
        attention_dropout=0.0,
        residual_dropout=0.0,
        embedding_dropout=0.0,
        use_bias: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.use_bias = use_bias

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_type = mlp_type
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.norm_epsilon = norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=config.use_bias,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.use_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(
        get_linear(weight, bias=bias, quantize=config.quantize)
    )


class Starcoder2Attention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.max_past = (
            config.sliding_window if config.sliding_window is not None else -1
        )
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=config.use_bias,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        if prefill_cache_indices is not None:
            kv_to_cache = kv[prefill_cache_indices]
        else:
            kv_to_cache = kv

        paged_attention.reshape_and_cache(
            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class Starcoder2MLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.c_fc = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.c_fc",
            weights=weights,
            bias=config.use_bias,
        )
        self.c_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.c_proj",
            weights=weights,
            bias=config.use_bias,
        )

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        return self.c_proj(hidden_states)


class Starcoder2GatedMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=config.use_bias,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=config.use_bias,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


STARCODER2_NORMALIZATION_CLASSES = {
    "layer_norm": FastLayerNorm,
    "rms_norm": FastRMSNorm,
}

STARCODER2_MLP_CLASSES = {
    "default": Starcoder2MLP,
    "gated": Starcoder2GatedMLP,
}


class Starcoder2Layer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = Starcoder2Attention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )

        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.norm_epsilon,
        )
        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[
            config.norm_type
        ].load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.norm_epsilon,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
        prefill_cache_indices,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            prefill_cache_indices,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class Starcoder2Model(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                Starcoder2Layer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
            prefix="model.norm", weights=weights, eps=config.norm_epsilon
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        true_max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, true_max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
                prefill_cache_indices,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashStarcoder2ForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = Starcoder2Model(config, weights)
        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="model.embed_tokens",
                weights=weights,
            )

        self.max_past = config.sliding_window
        self.max_past_tensor = (
            torch.tensor(config.sliding_window, device=weights.device)
            if self.max_past is not None
            else None
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        true_max_s = max_s
        if prefill_cache_indices is not None:
            # Slots also need to be sliced as it has the same size as the whole kv tensor
            slots = slots[prefill_cache_indices]
        elif self.max_past is not None:
            # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
            # kernel requires the true values
            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)

        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
            true_max_s,
            prefill_cache_indices,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits
```
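The attention module above builds `kv_head_mapping` with `repeat_interleave` so that during paged-attention decode each query head knows which shared key/value head to read, which is the core of grouped-query attention. With the default config (24 query heads, 2 KV heads) on a single shard, the mapping assigns each KV head a contiguous group of 12 query heads:

```python
import torch

num_heads, num_key_value_heads = 24, 2  # Starcoder2Config defaults, one shard
num_groups = num_heads // num_key_value_heads

kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

print(kv_head_mapping)
# tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32)
```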
server/text_generation_server/models/flash_mistral.py

```diff
@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from transformers.models.llama import LlamaTokenizerFast
-from typing import Optional, Tuple, Type, List
+from typing import Optional, Tuple, Type
 
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models import FlashCausalLM
@@ -38,6 +38,19 @@ SLIDING_WINDOW_BLOCKS: Optional[int] = None
 
 MEM_POOL = torch.cuda.graph_pool_handle()
 
 
+def set_sliding_window(sliding_window: int, sliding_window_blocks: int):
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    SLIDING_WINDOW = sliding_window
+    SLIDING_WINDOW_BLOCKS = sliding_window_blocks
+
+
+def get_sliding_windows() -> Tuple[int, int]:
+    global SLIDING_WINDOW
+    global SLIDING_WINDOW_BLOCKS
+    return SLIDING_WINDOW, SLIDING_WINDOW_BLOCKS
+
+
 # Adds windowing logic to FlashCausalLMBatch
 @dataclass
 class FlashMistralBatch(FlashCausalLMBatch):
@@ -53,8 +66,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
+        sliding_window, sliding_window_blocks = get_sliding_windows()
 
         batch_inputs = []
         max_truncation = 0
@@ -139,8 +151,8 @@ class FlashMistralBatch(FlashCausalLMBatch):
             # Needed blocks can not go over SLIDING_WINDOW_BLOCKS
             needed_blocks = math.ceil(total_tokens / BLOCK_SIZE)
-            if SLIDING_WINDOW_BLOCKS is not None:
-                needed_blocks = min(needed_blocks, SLIDING_WINDOW_BLOCKS)
+            if sliding_window_blocks is not None:
+                needed_blocks = min(needed_blocks, sliding_window_blocks)
             blocks += needed_blocks
             needed_blocks_slots.append((needed_blocks, total_tokens))
@@ -154,9 +166,9 @@ class FlashMistralBatch(FlashCausalLMBatch):
             slot_indices.append(request_slot_indices)
 
             # Create tensor to slice into the kv tensor in prefill
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 request_prefill_cache_indices = torch.arange(
-                    cumulative_length + max(0, input_length - SLIDING_WINDOW),
+                    cumulative_length + max(0, input_length - sliding_window),
                     cumulative_length + input_length,
                     dtype=torch.int64,
                 )
@@ -212,13 +224,13 @@ class FlashMistralBatch(FlashCausalLMBatch):
             input_ids = np.concatenate(all_input_ids, dtype=np.int64)
             position_ids = torch.cat(position_ids)
             slot_indices = torch.cat(slot_indices)
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = torch.cat(prefill_cache_indices)
         else:
             input_ids = all_input_ids[0]
             position_ids = position_ids[0]
             slot_indices = slot_indices[0]
-            if SLIDING_WINDOW is not None:
+            if sliding_window is not None:
                 prefill_cache_indices = prefill_cache_indices[0]
 
         cu_seqlen_prefill = torch.tensor(
@@ -228,7 +240,7 @@ class FlashMistralBatch(FlashCausalLMBatch):
         position_ids = position_ids.to(device)
         slot_indices = slot_indices.to(device)
         prefill_cache_indices = (
-            prefill_cache_indices.to(device) if SLIDING_WINDOW is not None else None
+            prefill_cache_indices.to(device) if sliding_window is not None else None
         )
         input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
         input_lengths_tensor = torch.tensor(
@@ -298,9 +310,6 @@ class BaseFlashMistral(FlashCausalLM):
         dtype: Optional[torch.dtype] = None,
         trust_remote_code: bool = False,
     ):
-        global SLIDING_WINDOW
-        global SLIDING_WINDOW_BLOCKS
-
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
@@ -324,8 +333,9 @@ class BaseFlashMistral(FlashCausalLM):
         # Set context windows
         if config.sliding_window is not None:
-            SLIDING_WINDOW = config.sliding_window
-            SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
+            set_sliding_window(
+                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
+            )
 
         torch.distributed.barrier(group=self.process_group)
```
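The net effect of this refactor is that the Mistral sliding-window state is now set and read through module-level helpers instead of `global` declarations scattered through the batch code, and `flash_starcoder2.py` below reuses the setter. A small hedged usage sketch (the value of `BLOCK_SIZE` is assumed here for illustration; the server imports its own constant):

```python
from text_generation_server.models.flash_mistral import (
    get_sliding_windows,
    set_sliding_window,
)

BLOCK_SIZE = 16  # assumed block size for this example

# Record a 4096-token window and its size in KV-cache blocks...
set_sliding_window(4096, 4096 // BLOCK_SIZE)

# ...and read both back wherever the batch code needs them.
sliding_window, sliding_window_blocks = get_sliding_windows()
assert (sliding_window, sliding_window_blocks) == (4096, 256)
```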
server/text_generation_server/models/flash_starcoder2.py (new file, 0 → 100644)

```python
import math

import torch

from typing import Optional

from transformers.models.gpt2 import GPT2TokenizerFast

from text_generation_server.models.cache_manager import BLOCK_SIZE
from text_generation_server.models.flash_mistral import (
    BaseFlashMistral,
    set_sliding_window,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
    Starcoder2Config,
    FlashStarcoder2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)


# Starcoder2 has the same base as Mistral
class FlashStarcoder2(BaseFlashMistral):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        tokenizer = GPT2TokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = Starcoder2Config.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        # Set context windows
        if config.sliding_window is not None:
            set_sliding_window(
                config.sliding_window, math.ceil(config.sliding_window / BLOCK_SIZE)
            )

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashStarcoder2ForCausalLM(config, weights)

        self.cuda_graphs = {}

        torch.distributed.barrier(group=self.process_group)
        super(BaseFlashMistral, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
            sliding_window=config.sliding_window,
        )
```
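For completeness, a hedged sketch of constructing the model class directly, mirroring the `model_type == "starcoder2"` branch that `get_model` gained above (in practice the server does this for you, and it requires a CUDA device plus torch.distributed initialization):

```python
from text_generation_server.models.flash_starcoder2 import FlashStarcoder2

# Direct construction for illustration only; quantize and dtype are left at
# their defaults, matching the dispatch call in models/__init__.py above.
model = FlashStarcoder2(
    "bigcode/starcoder2-3b",
    revision=None,
    quantize=None,
    dtype=None,
    trust_remote_code=False,
)
```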