OpenDAS / ollama · Commits · 32cb1960

Commit 32cb1960 (unverified), authored May 29, 2024 by Michael Yang, committed by GitHub on May 29, 2024

Merge pull request #4380 from ollama/mxyng/tokenize

use tokenize/detokenize

Parents: 646371f5, de781b37
Changes: 3
Showing 3 changed files with 66 additions and 242 deletions (+66 -242):

llm/ext_server/server.cpp (+9 -141)
llm/llm.go (+45 -0)
llm/server.go (+12 -101)
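As read from the diff below, this merge removes the ext server's infill mode, its system-prompt handling (the `system_prompt_process` helper, the `system_prompt` task field, and the `-spf/--system-prompt-file` flag), and the HTTP `/tokenize` and `/detokenize` endpoints from llm/ext_server/server.cpp. In their place, llm/llm.go gains an in-process, cgo-backed `llamaModel` wrapper with `Tokenize` and `Detokenize` methods, which llm/server.go now calls directly instead of making HTTP requests to the runner; usage sketches follow the llm/llm.go and llm/server.go hunks below.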
llm/ext_server/server.cpp (+9 -141)
```diff
@@ -140,7 +140,6 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
```
```diff
@@ -187,7 +186,6 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        infill = false;
         ga_i = 0;
         n_past_se = 0;
```
```diff
@@ -600,16 +598,6 @@ struct llama_server_context
             slot->params.n_predict = slot->n_predict;
         }
 
-        // infill
-        if (data.count("input_prefix") != 0)
-        {
-            slot->params.input_prefix = data["input_prefix"];
-        }
-        else
-        {
-            slot->params.input_prefix = "";
-        }
-
         if (data.count("input_suffix") != 0)
         {
             slot->params.input_suffix = data["input_suffix"];
```
```diff
@@ -897,15 +885,6 @@ struct llama_server_context
         system_need_update = true;
     }
 
-    void system_prompt_process(const json &sys_props) {
-        system_prompt  = sys_props.value("prompt", "");
-        name_user      = sys_props.value("anti_prompt", "");
-        name_assistant = sys_props.value("assistant_name", "");
-
-        system_prompt_notify();
-    }
-
     static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
                                         const stop_type type, server_slot &slot)
     {
```
```diff
@@ -1263,13 +1242,12 @@ struct llama_server_context
         queue_results.send(res);
     }
 
-    void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
+    void request_completion(int task_id, json data, bool embedding, int multitask_id)
     {
         task_server task;
         task.id = task_id;
         task.target_id = 0;
         task.data = std::move(data);
-        task.infill_mode = infill;
         task.embedding_mode = embedding;
         task.type = TASK_TYPE_COMPLETION;
         task.multitask_id = multitask_id;
```
```diff
@@ -1415,8 +1393,8 @@ struct llama_server_context
             json subtask_data = multiprompt_task.data;
             subtask_data["prompt"] = subtask_data["prompt"][i];
 
-            // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
+            // subtasks inherit everything else (embedding mode, etc.)
+            request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id);
         }
     }
```
```diff
@@ -1434,26 +1412,8 @@ struct llama_server_context
                     break;
                 }
 
-                if (task.data.contains("system_prompt"))
-                {
-                    if (!all_slots_are_idle) {
-                        send_error(task, "system prompt can only be updated when all slots are idle");
-                        break;
-                    }
-                    system_prompt_process(task.data["system_prompt"]);
-
-                    // reset cache_tokens for all slots
-                    for (server_slot &slot : slots)
-                    {
-                        slot.cache_tokens.clear();
-                        slot.n_past    = 0;
-                        slot.n_past_se = 0;
-                    }
-                }
-
                 slot->reset();
 
-                slot->infill       = task.infill_mode;
                 slot->embedding    = task.embedding_mode;
                 slot->task_id      = task.id;
                 slot->multitask_id = task.multitask_id;
```
```diff
@@ -1679,8 +1639,7 @@ struct llama_server_context
                 const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
 
                 // empty prompt passed -> release the slot and send empty response
-                // note: infill mode allows empty prompt
-                if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
+                if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt)
                 {
                     slot.release();
                     slot.print_timings();
```
```diff
@@ -1697,33 +1656,7 @@ struct llama_server_context
                     slot.t_start_process_prompt = ggml_time_us();
                     slot.t_start_genereration = 0;
 
-                    if (slot.infill)
-                    {
-                        bool suff_rm_leading_spc = true;
-                        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
-                        {
-                            params.input_suffix.erase(0, 1);
-                            suff_rm_leading_spc = false;
-                        }
-                        auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                        auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                        const int space_token = 29871; // TODO: this should not be hardcoded
-                        if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                            suffix_tokens.erase(suffix_tokens.begin());
-                        }
-
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                        prefix_tokens.insert(prefix_tokens.end(),   llama_token_suffix(model));
-                        prefix_tokens.insert(prefix_tokens.end(),   suffix_tokens.begin(), suffix_tokens.end());
-                        prefix_tokens.push_back(llama_token_middle(model));
-                        prompt_tokens = prefix_tokens;
-                    }
-                    else
-                    {
-                        prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
-                    }
+                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token);  // add BOS if there isn't system prompt
 
                     slot.n_prompt_tokens = prompt_tokens.size();
```
```diff
@@ -2130,8 +2063,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("\n");
 }
 
 static void server_params_parse(int argc, char **argv, server_params &sparams,
-                                gpt_params &params, llama_server_context &llama)
+                                gpt_params &params)
 {
     gpt_params default_params;
     server_params default_sparams;
```
```diff
@@ -2546,27 +2478,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_predict = std::stoi(argv[i]);
         }
-        else if (arg == "-spf" || arg == "--system-prompt-file")
-        {
-            if (++i >= argc)
-            {
-                invalid_param = true;
-                break;
-            }
-            std::ifstream file(argv[i]);
-            if (!file) {
-                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
-                invalid_param = true;
-                break;
-            }
-            std::string systm_content;
-            std::copy(
-                std::istreambuf_iterator<char>(file),
-                std::istreambuf_iterator<char>(),
-                std::back_inserter(systm_content)
-            );
-            llama.system_prompt_process(json::parse(systm_content));
-        }
         else if (arg == "-ctk" || arg == "--cache-type-k") {
             params.cache_type_k = argv[++i];
         }
```
```diff
@@ -2714,21 +2625,6 @@ static json format_partial_response(
     return res;
 }
 
-static json format_tokenizer_response(const std::vector<llama_token> &tokens)
-{
-    return json {
-        {"tokens", tokens}
-    };
-}
-
-static json format_detokenized_response(std::string content)
-{
-    return json {
-        {"content", content}
-    };
-}
-
 static void log_server_request(const httplib::Request &req, const httplib::Response &res)
 {
     // skip GH copilot requests when using default port
```
```diff
@@ -2818,7 +2714,7 @@ int main(int argc, char **argv) {
     // struct that contains llama context and inference
     llama_server_context llama;
 
-    server_params_parse(argc, argv, sparams, params, llama);
+    server_params_parse(argc, argv, sparams, params);
 
     if (params.model_alias == "unknown")
     {
```
```diff
@@ -3150,7 +3046,7 @@ int main(int argc, char **argv) {
                 json data = json::parse(req.body);
                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, data, false, false, -1);
+                llama.request_completion(task_id, data, false, -1);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.queue_results.recv(task_id);
```
```diff
@@ -3218,34 +3114,6 @@ int main(int argc, char **argv) {
                 }
             });
 
-    svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                const json body = json::parse(req.body);
-                std::vector<llama_token> tokens;
-                if (body.count("content") != 0)
-                {
-                    tokens = llama.tokenize(body["content"], false);
-                }
-                const json data = format_tokenizer_response(tokens);
-                return res.set_content(data.dump(), "application/json; charset=utf-8");
-            });
-
-    svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                const json body = json::parse(req.body);
-                std::string content;
-                if (body.count("tokens") != 0)
-                {
-                    const std::vector<llama_token> tokens = body["tokens"];
-                    content = tokens_to_str(llama.ctx, tokens.cbegin(), tokens.cend());
-                }
-
-                const json data = format_detokenized_response(content);
-                return res.set_content(data.dump(), "application/json; charset=utf-8");
-            });
-
     svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
```
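With the `/tokenize` and `/detokenize` handlers removed, the ext server no longer exposes tokenization over HTTP at all; of the three routes touched here only `/embedding` remains, and the Go layer stops calling the removed routes in llm/server.go below.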
```diff
@@ -3272,7 +3140,7 @@ int main(int argc, char **argv) {
                 // create and queue the task
                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
 
                 // get the result
                 task_result result = llama.queue_results.recv(task_id);
```
llm/llm.go (+45 -0)
```diff
@@ -12,6 +12,7 @@ package llm
 import "C"
 import (
 	"fmt"
+	"strings"
 	"unsafe"
 )
```
```diff
@@ -37,3 +38,47 @@ func Quantize(infile, outfile string, ftype fileType) error {
 	return nil
 }
+
+type llamaModel struct {
+	m *C.struct_llama_model
+}
+
+func newLlamaModel(p string) *llamaModel {
+	cs := C.CString(p)
+	defer C.free(unsafe.Pointer(cs))
+
+	return &llamaModel{
+		C.llama_load_model_from_file(
+			cs,
+			C.llama_model_default_params(),
+		),
+	}
+}
+
+func (llm *llamaModel) Close() {
+	C.llama_free_model(llm.m)
+}
+
+func (llm *llamaModel) Tokenize(s string) []int {
+	cs := C.CString(s)
+	defer C.free(unsafe.Pointer(cs))
+
+	tokens := make([]int, len(s)+2)
+	if n := C.llama_tokenize(llm.m, cs, C.int(len(s)), (*C.llama_token)(unsafe.Pointer(&tokens[0])), C.int(len(s)+2), false, true); n > 0 {
+		return tokens[:n]
+	}
+
+	return nil
+}
+
+func (llm *llamaModel) Detokenize(i32s []int) string {
+	var sb strings.Builder
+	for _, i32 := range i32s {
+		c := make([]byte, 512)
+		if n := C.llama_token_to_piece(llm.m, C.llama_token(i32), (*C.char)(unsafe.Pointer(&c[0])), C.int(len(c)), false); n > 0 {
+			sb.WriteString(unsafe.String(&c[0], n))
+		}
+	}
+
+	return sb.String()
+}
```
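The new `llamaModel` wrapper loads the model once via `llama_load_model_from_file` and exposes `Tokenize`/`Detokenize` as plain cgo calls; the token buffer is sized `len(s)+2`, presumably to leave room for special tokens such as BOS/EOS. Below is a minimal usage sketch, not part of the commit; it assumes it lives inside the `llm` package, the model path is a hypothetical placeholder, and error handling for a failed load is omitted because `newLlamaModel` does not return an error.

```go
package llm

// exampleTokenRoundTrip is illustrative only: it exercises the cgo-backed
// wrapper added above. The path below is a hypothetical placeholder.
func exampleTokenRoundTrip() string {
	m := newLlamaModel("/models/example.gguf") // hypothetical path
	defer m.Close()

	tokens := m.Tokenize("why is the sky blue?") // token ids, or nil on failure
	return m.Detokenize(tokens)                  // pieces joined back into a string
}
```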
llm/server.go (+12 -101)
```diff
@@ -57,6 +57,8 @@ type llmServer struct {
 	loadDuration   time.Duration // Record how long it took the model to load
 	loadProgress   float32
 
+	*llamaModel
+
 	sem *semaphore.Weighted
 }
```
```diff
@@ -306,6 +308,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
 			totalLayers:   ggml.KV().BlockCount() + 1,
 			gpuCount:      gpuCount,
 			done:          make(chan error, 1),
+			llamaModel:    newLlamaModel(model),
 		}
 
 		s.cmd.Env = os.Environ()
```
```diff
@@ -843,12 +846,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}
 
-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
-	if err != nil {
+	var b bytes.Buffer
+	if err := json.NewEncoder(&b).Encode(EmbeddingRequest{Content: prompt}); err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), &b)
 	if err != nil {
 		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
```
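The embedding request body is now encoded as an `EmbeddingRequest` through a `bytes.Buffer` and `json.NewEncoder`, rather than marshaling a `TokenizeRequest`, since the tokenize/detokenize request and response types are deleted in the next hunk.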
```diff
@@ -878,108 +881,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 	return embedding.Embedding, nil
 }
 
-type TokenizeRequest struct {
-	Content string `json:"content"`
-}
-
-type TokenizeResponse struct {
-	Tokens []int `json:"tokens"`
-}
-
 func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
-	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
-	if err != nil {
-		return nil, err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
-	}
-
-	data, err := json.Marshal(TokenizeRequest{Content: content})
-	if err != nil {
-		return nil, fmt.Errorf("marshaling encode data: %w", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
-	if err != nil {
-		return nil, fmt.Errorf("encode request: %w", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("do encode request: %w", err)
-	}
-	defer resp.Body.Close()
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("read encode request: %w", err)
-	}
-
-	if resp.StatusCode >= 400 {
-		log.Printf("llm encode error: %s", body)
-		return nil, fmt.Errorf("%s", body)
-	}
-
-	var encoded TokenizeResponse
-	if err := json.Unmarshal(body, &encoded); err != nil {
-		return nil, fmt.Errorf("unmarshal encode response: %w", err)
-	}
-
-	return encoded.Tokens, nil
-}
-
-type DetokenizeRequest struct {
-	Tokens []int `json:"tokens"`
-}
-
-type DetokenizeResponse struct {
-	Content string `json:"content"`
-}
-
+	return s.llamaModel.Tokenize(content), nil
+}
+
 func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
-	if err != nil {
-		return "", err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
-		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
-	}
-
-	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
-	if err != nil {
-		return "", fmt.Errorf("marshaling decode data: %w", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
-	if err != nil {
-		return "", fmt.Errorf("decode request: %w", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return "", fmt.Errorf("do decode request: %w", err)
-	}
-	defer resp.Body.Close()
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return "", fmt.Errorf("read decode request: %w", err)
-	}
-
-	if resp.StatusCode >= 400 {
-		log.Printf("llm decode error: %s", body)
-		return "", fmt.Errorf("%s", body)
-	}
-
-	var decoded DetokenizeResponse
-	if err := json.Unmarshal(body, &decoded); err != nil {
-		return "", fmt.Errorf("unmarshal encode response: %w", err)
-	}
-
-	return decoded.Content, nil
+	return s.llamaModel.Detokenize(tokens), nil
 }
 
 func (s *llmServer) Close() error {
```
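`Tokenize` and `Detokenize` keep their signatures but now delegate to the embedded `*llamaModel`, so they no longer depend on the runner subprocess being up or on its HTTP port; the `ctx` argument is effectively unused by the new bodies. A caller-side sketch, not part of the commit, assuming an existing `*llmServer` value inside the `llm` package:

```go
package llm

import "context"

// tokenRoundTrip is an illustrative helper (hypothetical, not in the commit):
// callers see the same Tokenize/Detokenize API as before, only the transport
// changed from HTTP requests to in-process cgo calls.
func tokenRoundTrip(ctx context.Context, s *llmServer, prompt string) (string, error) {
	tokens, err := s.Tokenize(ctx, prompt)
	if err != nil {
		return "", err
	}
	return s.Detokenize(ctx, tokens)
}
```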
```diff
@@ -997,6 +904,10 @@ func (s *llmServer) Close() error {
 		slog.Debug("llama server stopped")
 	}
 
+	if s.llamaModel != nil {
+		s.llamaModel.Close()
+	}
+
 	return nil
 }
```