Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Conformer_pytorch
Commits
764b3a75
Commit
764b3a75
authored
Jun 07, 2023
by
Sugon_ldc
Browse files
add new model
parents
Changes
498
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4129 additions
and
0 deletions
+4129
-0
runtime/core/http/CMakeLists.txt
runtime/core/http/CMakeLists.txt
+5
-0
runtime/core/http/http_client.cc
runtime/core/http/http_client.cc
+64
-0
runtime/core/http/http_client.h
runtime/core/http/http_client.h
+66
-0
runtime/core/http/http_server.cc
runtime/core/http/http_server.cc
+193
-0
runtime/core/http/http_server.h
runtime/core/http/http_server.h
+102
-0
runtime/core/kaldi/CMakeLists.txt
runtime/core/kaldi/CMakeLists.txt
+54
-0
runtime/core/kaldi/README.md
runtime/core/kaldi/README.md
+21
-0
runtime/core/kaldi/base/io-funcs-inl.h
runtime/core/kaldi/base/io-funcs-inl.h
+329
-0
runtime/core/kaldi/base/io-funcs.cc
runtime/core/kaldi/base/io-funcs.cc
+215
-0
runtime/core/kaldi/base/io-funcs.h
runtime/core/kaldi/base/io-funcs.h
+246
-0
runtime/core/kaldi/base/kaldi-common.h
runtime/core/kaldi/base/kaldi-common.h
+41
-0
runtime/core/kaldi/base/kaldi-error.cc
runtime/core/kaldi/base/kaldi-error.cc
+42
-0
runtime/core/kaldi/base/kaldi-error.h
runtime/core/kaldi/base/kaldi-error.h
+57
-0
runtime/core/kaldi/base/kaldi-math.cc
runtime/core/kaldi/base/kaldi-math.cc
+164
-0
runtime/core/kaldi/base/kaldi-math.h
runtime/core/kaldi/base/kaldi-math.h
+363
-0
runtime/core/kaldi/base/kaldi-types.h
runtime/core/kaldi/base/kaldi-types.h
+75
-0
runtime/core/kaldi/base/kaldi-utils.h
runtime/core/kaldi/base/kaldi-utils.h
+155
-0
runtime/core/kaldi/decoder/lattice-faster-decoder.cc
runtime/core/kaldi/decoder/lattice-faster-decoder.cc
+1101
-0
runtime/core/kaldi/decoder/lattice-faster-decoder.h
runtime/core/kaldi/decoder/lattice-faster-decoder.h
+558
-0
runtime/core/kaldi/decoder/lattice-faster-online-decoder.cc
runtime/core/kaldi/decoder/lattice-faster-online-decoder.cc
+278
-0
No files found.
Too many changes to show.
To preserve performance only
498 of 498+
files are displayed.
Plain diff
Email patch
runtime/core/http/CMakeLists.txt
0 → 100644
View file @
764b3a75
# Static library with the HTTP client/server used by the wenet runtime.
add_library(http STATIC
  http_client.cc
  http_server.cc
)
# The server drives an AsrDecoder, hence the dependency on the decoder lib.
target_link_libraries(http PUBLIC decoder)
runtime/core/http/http_client.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "http/http_client.h"
#include "boost/json/src.hpp"
#include "utils/log.h"
namespace wenet {

// Short aliases for the Boost networking namespaces used throughout this file.
namespace beast = boost::beast;  // from <boost/beast.hpp>
namespace http = beast::http;    // from <boost/beast/http.hpp>
namespace net = boost::asio;     // from <boost/asio.hpp>
using tcp = net::ip::tcp;        // from <boost/asio/ip/tcp.hpp>
namespace json = boost::json;
// Construct a client for the given server and open the TCP connection
// immediately (Connect() may throw via Boost on resolution/connect failure).
HttpClient::HttpClient(const std::string& hostname, int port)
    : hostname_(hostname), port_(port) {
  Connect();
}
// Resolve hostname_:port_ and connect the Beast TCP stream to the first
// endpoint that succeeds.
void HttpClient::Connect() {
  tcp::resolver resolver{ioc_};
  // Look up the domain name
  auto const results = resolver.resolve(hostname_, std::to_string(port_));
  stream_.connect(results);
}
void
HttpClient
::
SendBinaryData
(
const
void
*
data
,
size_t
size
)
{
try
{
json
::
value
start_tag
=
{{
"nbest"
,
nbest_
},
{
"continuous_decoding"
,
continuous_decoding_
}};
std
::
string
config
=
json
::
serialize
(
start_tag
);
req_
.
set
(
"config"
,
config
);
std
::
size_t
encode_size
=
beast
::
detail
::
base64
::
encoded_size
(
size
);
char
encode_data
[
encode_size
];
// NOLINT
beast
::
detail
::
base64
::
encode
(
encode_data
,
data
,
size
);
req_
.
body
()
=
encode_data
;
req_
.
prepare_payload
();
http
::
write
(
stream_
,
req_
,
ec_
);
http
::
read
(
stream_
,
buffer_
,
res_
);
std
::
string
message
=
res_
.
body
();
json
::
object
obj
=
json
::
parse
(
message
).
as_object
();
LOG
(
INFO
)
<<
message
;
}
catch
(
std
::
exception
const
&
e
)
{
LOG
(
ERROR
)
<<
e
.
what
();
}
stream_
.
socket
().
shutdown
(
tcp
::
socket
::
shutdown_both
,
ec_
);
}
}
// namespace wenet
runtime/core/http/http_client.h
0 → 100644
View file @
764b3a75
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HTTP_HTTP_CLIENT_H_
#define HTTP_HTTP_CLIENT_H_
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <boost/asio/connect.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/core/detail/base64.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include "utils/utils.h"
namespace wenet {

namespace beast = boost::beast;  // from <boost/beast.hpp>
namespace http = beast::http;    // from <boost/beast/http.hpp>
namespace net = boost::asio;     // from <boost/asio.hpp>
using tcp = net::ip::tcp;        // from <boost/asio/ip/tcp.hpp>

// Blocking HTTP client that ships one utterance of audio to a wenet HTTP
// server (see http_server.h) and logs the recognition result.
// Not copyable; owns its io_context and TCP stream.
class HttpClient {
 public:
  // Connects to host:port immediately (may throw on failure).
  HttpClient(const std::string& host, int port);

  // Sends `size` bytes of raw PCM at `data`, base64-encoded, and waits for
  // the server's response.
  void SendBinaryData(const void* data, size_t size);
  // Number of hypotheses requested from the server (sent in the config header).
  void set_nbest(int nbest) { nbest_ = nbest; }

 private:
  void Connect();
  std::string hostname_;
  int port_;
  std::string target_ = "/";  // request path
  int version_ = 11;          // HTTP/1.1
  int nbest_ = 1;
  const bool continuous_decoding_ = false;  // fixed: one-shot decoding only
  net::io_context ioc_;
  beast::tcp_stream stream_{ioc_};
  beast::flat_buffer buffer_;
  // NOTE: req_/res_ reference target_/version_ above, so declaration order
  // matters for member initialization.
  http::request<http::string_body> req_{http::verb::get, target_, version_};
  http::response<http::string_body> res_{http::status::ok, version_};
  beast::error_code ec_;

  WENET_DISALLOW_COPY_AND_ASSIGN(HttpClient);
};

}  // namespace wenet

#endif  // HTTP_HTTP_CLIENT_H_
runtime/core/http/http_server.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "http/http_server.h"
#include <thread>
#include <utility>
#include <vector>
#include "boost/json/src.hpp"
#include "utils/log.h"
namespace wenet {

// Short aliases for the Boost networking namespaces used throughout this file.
namespace beast = boost::beast;    // from <boost/beast.hpp>
namespace http = beast::http;      // from <boost/beast/http.hpp>
namespace net = boost::asio;       // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>
namespace json = boost::json;
// One handler per accepted connection. Takes ownership of the socket and
// shares the (read-only) feature/decode configuration and resources with the
// server. req_/res_ are pre-built HTTP/1.1 message shells that the connection
// fills in.
ConnectionHandler::ConnectionHandler(
    tcp::socket&& socket, std::shared_ptr<FeaturePipelineConfig> feature_config,
    std::shared_ptr<DecodeOptions> decode_config,
    std::shared_ptr<DecodeResource> decode_resource)
    : socket_(std::move(socket)),
      feature_config_(std::move(feature_config)),
      decode_config_(std::move(decode_config)),
      decode_resource_(std::move(decode_resource)),
      // target_ / version_ are declared before req_/res_ in the header, so
      // they are already initialized when used here.
      req_(std::make_shared<http::request<http::string_body>>(
          http::verb::post, target_, version_)),
      res_(std::make_shared<http::response<http::string_body>>(
          http::status::ok, version_)) {}
// Create a fresh feature pipeline + decoder for this utterance and launch the
// decoding thread, which runs DecodeThreadFunc until the utterance ends.
void ConnectionHandler::OnSpeechStart() {
  feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
  decoder_ = std::make_shared<AsrDecoder>(feature_pipeline_, decode_resource_,
                                          *decode_config_);
  // Start decoder thread
  decode_thread_ = std::make_shared<std::thread>(
      &ConnectionHandler::DecodeThreadFunc, this);
}
// Mark the feature pipeline's input as finished so the decoding thread can
// drain remaining features and emit the final result.
void ConnectionHandler::OnSpeechEnd() {
  if (feature_pipeline_ != nullptr) {
    feature_pipeline_->set_input_finished();
  }
}
void
ConnectionHandler
::
OnFinalResult
(
const
std
::
string
&
result
)
{
LOG
(
INFO
)
<<
"Final result: "
<<
result
;
json
::
value
rv
=
{
{
"status"
,
"ok"
},
{
"type"
,
"final_result"
},
{
"nbest"
,
result
}};
std
::
string
message
=
json
::
serialize
(
rv
);
res_
.
get
()
->
body
()
=
message
;
http
::
write
(
socket_
,
*
res_
.
get
(),
ec_
);
}
// Decode the base64 request body into 16-bit PCM samples and feed them to the
// feature pipeline. Requires OnSpeechStart() to have been called first.
//
// @param message  base64-encoded little-endian int16 PCM audio
void ConnectionHandler::OnSpeechData(const std::string& message) {
  // decoded_size() is an upper bound on the decoded byte count.
  std::size_t decode_size =
      beast::detail::base64::decoded_size(message.length());
  int num_samples = decode_size / sizeof(int16_t);
  // Fix: the original used a variable-length array (`int16_t
  // decode_data[num_samples]`), which is not standard C++ and risks stack
  // overflow for large payloads; use a heap-backed vector instead.
  std::vector<int16_t> decode_data(num_samples);
  beast::detail::base64::decode(decode_data.data(), message.c_str(),
                                message.length());
  // Read binary PCM data
  VLOG(2) << "Received " << num_samples << " samples";
  CHECK(feature_pipeline_ != nullptr);
  CHECK(decoder_ != nullptr);
  feature_pipeline_->AcceptWaveform(decode_data.data(), num_samples);
}
// Serialize up to nbest_ decoder hypotheses as a JSON array of
// {"sentence": ...} objects; when `finish` is true each hypothesis also
// carries per-word timestamps under "word_pieces".
std::string ConnectionHandler::SerializeResult(bool finish) {
  json::array nbest;
  for (const DecodeResult& path : decoder_->result()) {
    json::object jpath({{"sentence", path.sentence}});
    if (finish) {
      json::array word_pieces;
      for (const WordPiece& word_piece : path.word_pieces) {
        json::object jword_piece({{"word", word_piece.word},
                                  {"start", word_piece.start},
                                  {"end", word_piece.end}});
        word_pieces.emplace_back(jword_piece);
      }
      jpath.emplace("word_pieces", word_pieces);
    }
    nbest.emplace_back(jpath);
    // NOTE(review): signed/unsigned comparison (size() vs int nbest_); relies
    // on nbest_ >= 1 — confirm upstream validation of the "nbest" option.
    if (nbest.size() == nbest_) {
      break;
    }
  }
  return json::serialize(nbest);
}
// Body of the decoding thread: repeatedly advance the decoder until the
// features are exhausted or an endpoint is detected, then rescore once and
// send the final result. Exceptions are logged, never propagated out of the
// thread.
void ConnectionHandler::DecodeThreadFunc() {
  try {
    while (true) {
      DecodeState state = decoder_->Decode();
      if (state == DecodeState::kEndFeats || state == DecodeState::kEndpoint) {
        decoder_->Rescoring();
        std::string result = SerializeResult(true);
        OnFinalResult(result);
        // Single-shot decoding: stop after the first final result
        // (continuous_decoding_ is fixed to false).
        break;
      }
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
}
void
ConnectionHandler
::
OnError
(
const
std
::
string
&
message
)
{
json
::
value
rv
=
{{
"status"
,
"failed"
},
{
"message"
,
message
}};
res_
.
get
()
->
body
()
=
json
::
serialize
(
rv
);
http
::
write
(
socket_
,
*
res_
.
get
(),
ec_
);
// Send a TCP shutdown
socket_
.
shutdown
(
tcp
::
socket
::
shutdown_send
,
ec_
);
}
// Parse the JSON "config" header sent by the client. Currently only the
// integer "nbest" option is honored; anything that is not a JSON object is
// rejected via OnError.
void ConnectionHandler::OnText(const std::string& message) {
  LOG(INFO) << message;
  json::value v = json::parse(message);
  if (v.is_object()) {
    json::object obj = v.get_object();
    if (obj.find("nbest") != obj.end()) {
      if (obj["nbest"].is_int64()) {
        nbest_ = obj["nbest"].as_int64();
      } else {
        OnError("integer is expected for nbest option");
      }
    }
  } else {
    OnError("Wrong protocol");
  }
}
// Entry point of the per-connection thread (invoked via std::thread in
// HttpServer::Start). Reads the whole request, applies the config header,
// feeds the audio body to the decoder, then joins the decoding thread and
// half-closes the socket. All failure paths still join decode_thread_ so the
// thread is never leaked.
void ConnectionHandler::operator()() {
  try {
    http::read(socket_, buffer_, *req_.get(), ec_);
    if (ec_) {
      LOG(ERROR) << ec_;
    } else {
      OnText(req_.get()->base()["config"].to_string());
      OnSpeechStart();
      OnSpeechData(req_.get()->body());
      OnSpeechEnd();
    }
    LOG(INFO) << "Read all pcm data, wait for decoding thread";
    if (decode_thread_ != nullptr) {
      decode_thread_->join();
    }
  } catch (beast::system_error const& se) {
    // Expected on normal disconnects; log at INFO, not ERROR.
    LOG(INFO) << se.code().message();
    if (decode_thread_ != nullptr) {
      decode_thread_->join();
    }
  } catch (std::exception const& e) {
    LOG(ERROR) << e.what();
  }
  socket_.shutdown(tcp::socket::shutdown_send, ec_);
}
// Accept loop: listen on 0.0.0.0:port_ forever, spawning one detached thread
// per connection. The handler object (and the socket it owns) is moved into
// the thread. Fatal setup errors (e.g. port in use) abort via LOG(FATAL).
void HttpServer::Start() {
  try {
    auto const address = net::ip::make_address("0.0.0.0");
    tcp::acceptor acceptor{ioc_, {address, static_cast<uint16_t>(port_)}};
    for (;;) {
      // This will receive the new connection
      tcp::socket socket{ioc_};
      // Block until we get a connection
      acceptor.accept(socket);
      // Launch the session, transferring ownership of the socket
      ConnectionHandler handler(std::move(socket), feature_config_,
                                decode_config_, decode_resource_);
      std::thread t(std::move(handler));
      t.detach();
    }
  } catch (const std::exception& e) {
    LOG(FATAL) << e.what();
  }
}
}
// namespace wenet
runtime/core/http/http_server.h
0 → 100644
View file @
764b3a75
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef HTTP_HTTP_SERVER_H_
#define HTTP_HTTP_SERVER_H_
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <boost/asio/ip/tcp.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/core/detail/base64.hpp>
#include <boost/beast/http.hpp>
#include <boost/beast/version.hpp>
#include <boost/config.hpp>
#include "decoder/asr_decoder.h"
#include "frontend/feature_pipeline.h"
#include "utils/log.h"
namespace wenet {

// Short aliases for the Boost networking namespaces used by the HTTP server.
namespace beast = boost::beast;    // from <boost/beast.hpp>
namespace http = beast::http;      // from <boost/beast/http.hpp>
namespace net = boost::asio;       // from <boost/asio.hpp>
using tcp = boost::asio::ip::tcp;  // from <boost/asio/ip/tcp.hpp>
// Handles a single accepted HTTP connection: parses the config header, feeds
// the base64 audio body through the feature pipeline, runs the decoder on a
// dedicated thread and writes the JSON result back. Designed to be moved into
// a std::thread and invoked via operator()().
class ConnectionHandler {
 public:
  ConnectionHandler(tcp::socket&& socket,
                    std::shared_ptr<FeaturePipelineConfig> feature_config,
                    std::shared_ptr<DecodeOptions> decode_config,
                    std::shared_ptr<DecodeResource> decode_resource_);
  // Thread entry point; runs the whole request/response lifecycle.
  void operator()();

 private:
  void OnSpeechStart();
  void OnSpeechEnd();
  void OnText(const std::string& message);
  void OnSpeechData(const std::string& message);
  void OnError(const std::string& message);
  void OnFinalResult(const std::string& result);
  void DecodeThreadFunc();
  std::string SerializeResult(bool finish);

  std::string target_ = "/";  // request path
  int version_ = 11;          // HTTP/1.1
  const bool continuous_decoding_ = false;  // one-shot decoding only
  int nbest_ = 1;  // number of hypotheses to return (set via config header)
  tcp::socket socket_;
  beast::flat_buffer buffer_;
  beast::error_code ec_;
  std::shared_ptr<http::request<http::string_body>> req_;
  std::shared_ptr<http::response<http::string_body>> res_;
  // Shared, read-only configuration/resources owned jointly with HttpServer.
  std::shared_ptr<FeaturePipelineConfig> feature_config_;
  std::shared_ptr<DecodeOptions> decode_config_;
  std::shared_ptr<DecodeResource> decode_resource_;
  // Per-utterance state, created in OnSpeechStart().
  std::shared_ptr<FeaturePipeline> feature_pipeline_ = nullptr;
  std::shared_ptr<AsrDecoder> decoder_ = nullptr;
  std::shared_ptr<std::thread> decode_thread_ = nullptr;
};
// Blocking HTTP ASR server: accepts connections forever and hands each one to
// a detached ConnectionHandler thread. Not copyable.
class HttpServer {
 public:
  HttpServer(int port, std::shared_ptr<FeaturePipelineConfig> feature_config,
             std::shared_ptr<DecodeOptions> decode_config,
             std::shared_ptr<DecodeResource> decode_resource)
      : port_(port),
        feature_config_(std::move(feature_config)),
        decode_config_(std::move(decode_config)),
        decode_resource_(std::move(decode_resource)) {}

  // Runs the accept loop; does not return under normal operation.
  void Start();

 private:
  int port_;
  // The io_context is required for all I/O
  net::io_context ioc_{1};
  // Shared with every ConnectionHandler spawned by Start().
  std::shared_ptr<FeaturePipelineConfig> feature_config_;
  std::shared_ptr<DecodeOptions> decode_config_;
  std::shared_ptr<DecodeResource> decode_resource_;
  WENET_DISALLOW_COPY_AND_ASSIGN(HttpServer);
};

}  // namespace wenet

#endif  // HTTP_HTTP_SERVER_H_
runtime/core/kaldi/CMakeLists.txt
0 → 100644
View file @
764b3a75
cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(kaldi)

# include_directories() is called in the root CMakeLists.txt

# Base utilities shared by the decoder and the optional graph tools.
add_library(kaldi-util
  base/kaldi-error.cc
  base/kaldi-math.cc
  util/kaldi-io.cc
  util/parse-options.cc
  util/simple-io-funcs.cc
  util/text-utils.cc
)
target_link_libraries(kaldi-util PUBLIC utils)

# Lattice decoders used for TLG-based decoding.
add_library(kaldi-decoder
  lat/determinize-lattice-pruned.cc
  lat/lattice-functions.cc
  decoder/lattice-faster-decoder.cc
  decoder/lattice-faster-online-decoder.cc
)
target_link_libraries(kaldi-decoder PUBLIC kaldi-util)

# Optional command-line tools for building decoding graphs.
if(GRAPH_TOOLS)
  # Arpa binary
  add_executable(arpa2fst
    lm/arpa-file-parser.cc
    lm/arpa-lm-compiler.cc
    lmbin/arpa2fst.cc
  )
  target_link_libraries(arpa2fst PUBLIC kaldi-util)

  # FST tools binary
  set(FST_BINS
    fstaddselfloops
    fstdeterminizestar
    fstisstochastic
    fstminimizeencoded
    fsttablecompose
  )

  if(NOT MSVC)
    # dl is for dynamic linking, otherwise there is a linking error on linux
    link_libraries(dl)
  endif()

  # One executable per FST tool, each built from fstbin/<name>.cc.
  foreach(name IN LISTS FST_BINS)
    add_executable(${name} fstbin/${name}.cc fstext/kaldi-fst-io.cc)
    target_link_libraries(${name} PUBLIC kaldi-util)
  endforeach()
endif()
runtime/core/kaldi/README.md
0 → 100644
View file @
764b3a75
We use Kaldi decoder to implement TLG based language model integration,
so we copied related files to this directory.
The main changes are:
1.
To minimize the changes, we use the same directory tree as Kaldi.
2.
We replace Kaldi log system with glog in the following way.
```
c++
#define KALDI_WARN \
google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING).stream()
#define KALDI_ERR \
google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR).stream()
#define KALDI_INFO \
google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO).stream()
#define KALDI_VLOG(v) VLOG(v)
#define KALDI_ASSERT(condition) CHECK(condition)
```
3.
We lint all the files to satisfy the lint in WeNet.
runtime/core/kaldi/base/io-funcs-inl.h
0 → 100644
View file @
764b3a75
// base/io-funcs-inl.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Jan Silovsky; Yanmin Qian;
// Johns Hopkins University (Author: Daniel Povey)
// 2016 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_IO_FUNCS_INL_H_
#define KALDI_BASE_IO_FUNCS_INL_H_ 1
// Do not include this file directly. It is included by base/io-funcs.h
#include <limits>
#include <vector>
#include <utility>
namespace
kaldi
{
// Template that covers integers.
// Writes one integer of type T to `os`. In binary mode a one-byte size/sign
// tag (+sizeof for signed, -sizeof for unsigned) precedes the raw bytes so
// the reader can verify the type; in text mode the value is written as
// decimal followed by a space (char-sized types are widened to int16 so they
// print as numbers, not characters).
template <class T>
void WriteBasicType(std::ostream& os, bool binary, T t) {
  // Compile time assertion that this is not called with a wrong type.
  KALDI_ASSERT_IS_INTEGER_TYPE(T);
  if (binary) {
    char len_c = (std::numeric_limits<T>::is_signed ? 1 : -1) *
                 static_cast<char>(sizeof(t));
    os.put(len_c);
    os.write(reinterpret_cast<const char*>(&t), sizeof(t));
  } else {
    if (sizeof(t) == 1)
      os << static_cast<int16>(t) << " ";
    else
      os << t << " ";
  }
  if (os.fail()) {
    KALDI_ERR << "Write failure in WriteBasicType.";
  }
}
// Template that covers integers.
// Reads one integer of type T from `is`, reversing WriteBasicType: in binary
// mode the one-byte size/sign tag is checked before reading the raw bytes;
// in text mode a decimal value is parsed (char-sized types via int16 so
// digits are read rather than a single character). Errors raise KALDI_ERR.
template <class T>
inline void ReadBasicType(std::istream& is, bool binary, T* t) {
  KALDI_PARANOID_ASSERT(t != NULL);
  // Compile time assertion that this is not called with a wrong type.
  KALDI_ASSERT_IS_INTEGER_TYPE(T);
  if (binary) {
    int len_c_in = is.get();
    if (len_c_in == -1)
      KALDI_ERR << "ReadBasicType: encountered end of stream.";
    char len_c = static_cast<char>(len_c_in),
         len_c_expected = (std::numeric_limits<T>::is_signed ? 1 : -1) *
                          static_cast<char>(sizeof(*t));
    if (len_c != len_c_expected) {
      KALDI_ERR << "ReadBasicType: did not get expected integer type, "
                << static_cast<int>(len_c) << " vs. "
                << static_cast<int>(len_c_expected)
                << ". You can change this code to successfully"
                << " read it later, if needed.";
      // insert code here to read "wrong" type. Might have a switch statement.
    }
    is.read(reinterpret_cast<char*>(t), sizeof(*t));
  } else {
    if (sizeof(*t) == 1) {
      int16 i;
      is >> i;
      *t = i;
    } else {
      is >> *t;
    }
  }
  if (is.fail()) {
    KALDI_ERR << "Read failure in ReadBasicType, file position is "
              << is.tellg() << ", next char is " << is.peek();
  }
}
// Template that covers integers.
// Writes a vector of (T, T) pairs. Binary layout: one byte sizeof(T) as a
// sanity tag, an int32 element count, then the raw pair data in one write.
// Text layout: "[ a,b c,d ... ]" followed by a newline.
template <class T>
inline void WriteIntegerPairVector(std::ostream& os, bool binary,
                                   const std::vector<std::pair<T, T> >& v) {
  // Compile time assertion that this is not called with a wrong type.
  KALDI_ASSERT_IS_INTEGER_TYPE(T);
  if (binary) {
    char sz = sizeof(T);  // this is currently just a check.
    os.write(&sz, 1);
    int32 vecsz = static_cast<int32>(v.size());
    KALDI_ASSERT((size_t)vecsz == v.size());
    os.write(reinterpret_cast<const char*>(&vecsz), sizeof(vecsz));
    if (vecsz != 0) {
      // Pairs of trivially-copyable T are contiguous, so one bulk write works.
      os.write(reinterpret_cast<const char*>(&(v[0])), sizeof(T) * vecsz * 2);
    }
  } else {
    // focus here is on prettiness of text form rather than
    // efficiency of reading-in.
    // reading-in is dominated by low-level operations anyway:
    // for efficiency use binary.
    os << "[ ";
    typename std::vector<std::pair<T, T> >::const_iterator iter = v.begin(),
                                                           end = v.end();
    for (; iter != end; ++iter) {
      if (sizeof(T) == 1)
        os << static_cast<int16>(iter->first) << ','
           << static_cast<int16>(iter->second) << ' ';
      else
        os << iter->first << ',' << iter->second << ' ';
    }
    os << "]\n";
  }
  if (os.fail()) {
    KALDI_ERR << "Write failure in WriteIntegerPairVector.";
  }
}
// Template that covers integers.
// Reads a vector of (T, T) pairs written by WriteIntegerPairVector: in binary
// mode the sizeof tag and int32 count are validated before one bulk read; in
// text mode pairs "a,b" are parsed between '[' and ']'. Any failure jumps to
// the shared `bad:` error exit.
template <class T>
inline void ReadIntegerPairVector(std::istream& is, bool binary,
                                  std::vector<std::pair<T, T> >* v) {
  KALDI_ASSERT_IS_INTEGER_TYPE(T);
  KALDI_ASSERT(v != NULL);
  if (binary) {
    int sz = is.peek();
    if (sz == sizeof(T)) {
      is.get();
    } else {  // this is currently just a check.
      KALDI_ERR << "ReadIntegerPairVector: expected to see type of size "
                << sizeof(T) << ", saw instead " << sz
                << ", at file position " << is.tellg();
    }
    int32 vecsz;
    is.read(reinterpret_cast<char*>(&vecsz), sizeof(vecsz));
    if (is.fail() || vecsz < 0) goto bad;
    v->resize(vecsz);
    if (vecsz > 0) {
      is.read(reinterpret_cast<char*>(&((*v)[0])), sizeof(T) * vecsz * 2);
    }
  } else {
    std::vector<std::pair<T, T> > tmp_v;  // use temporary so v doesn't use
                                          // extra memory due to resizing.
    is >> std::ws;
    if (is.peek() != static_cast<int>('[')) {
      KALDI_ERR << "ReadIntegerPairVector: expected to see [, saw "
                << is.peek() << ", at file position " << is.tellg();
    }
    is.get();  // consume the '['.
    is >> std::ws;  // consume whitespace.
    while (is.peek() != static_cast<int>(']')) {
      if (sizeof(T) == 1) {  // read/write chars as numbers.
        int16 next_t1, next_t2;
        is >> next_t1;
        if (is.fail()) goto bad;
        if (is.peek() != static_cast<int>(','))
          KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw "
                    << is.peek() << ", at file position " << is.tellg();
        is.get();  // consume the ','.
        is >> next_t2 >> std::ws;
        if (is.fail())
          goto bad;
        else
          tmp_v.push_back(std::make_pair((T)next_t1, (T)next_t2));
      } else {
        T next_t1, next_t2;
        is >> next_t1;
        if (is.fail()) goto bad;
        if (is.peek() != static_cast<int>(','))
          KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw "
                    << is.peek() << ", at file position " << is.tellg();
        is.get();  // consume the ','.
        is >> next_t2 >> std::ws;
        if (is.fail())
          goto bad;
        else
          tmp_v.push_back(std::pair<T, T>(next_t1, next_t2));
      }
    }
    is.get();  // get the final ']'.
    *v = tmp_v;
    // could use std::swap to use less temporary memory, but this
    // uses less permanent memory.
  }
  if (!is.fail()) return;
bad:
  KALDI_ERR << "ReadIntegerPairVector: read failure at file position "
            << is.tellg();
}
// Writes a vector of integers of type T. Binary layout: one byte sizeof(T) as
// a sanity tag, an int32 element count, then the raw data in one write. Text
// layout: "[ a b c ... ]" followed by a newline.
template <class T>
inline void WriteIntegerVector(std::ostream& os, bool binary,
                               const std::vector<T>& v) {
  // Compile time assertion that this is not called with a wrong type.
  KALDI_ASSERT_IS_INTEGER_TYPE(T);
  if (binary) {
    char sz = sizeof(T);  // this is currently just a check.
    os.write(&sz, 1);
    int32 vecsz = static_cast<int32>(v.size());
    KALDI_ASSERT((size_t)vecsz == v.size());
    os.write(reinterpret_cast<const char*>(&vecsz), sizeof(vecsz));
    if (vecsz != 0) {
      os.write(reinterpret_cast<const char*>(&(v[0])), sizeof(T) * vecsz);
    }
  } else {
    // focus here is on prettiness of text form rather than
    // efficiency of reading-in.
    // reading-in is dominated by low-level operations anyway:
    // for efficiency use binary.
    os << "[ ";
    typename std::vector<T>::const_iterator iter = v.begin(), end = v.end();
    for (; iter != end; ++iter) {
      if (sizeof(T) == 1)
        os << static_cast<int16>(*iter) << " ";
      else
        os << *iter << " ";
    }
    os << "]\n";
  }
  if (os.fail()) {
    KALDI_ERR << "Write failure in WriteIntegerVector.";
  }
}
// Reads a vector of integers of type T written by WriteIntegerVector: binary
// mode validates the sizeof tag and int32 count then bulk-reads; text mode
// parses whitespace-separated values between '[' and ']'. Any failure jumps
// to the shared `bad:` error exit.
template <class T>
inline void ReadIntegerVector(std::istream& is, bool binary,
                              std::vector<T>* v) {
  KALDI_ASSERT_IS_INTEGER_TYPE(T);
  KALDI_ASSERT(v != NULL);
  if (binary) {
    int sz = is.peek();
    if (sz == sizeof(T)) {
      is.get();
    } else {  // this is currently just a check.
      KALDI_ERR << "ReadIntegerVector: expected to see type of size "
                << sizeof(T) << ", saw instead " << sz
                << ", at file position " << is.tellg();
    }
    int32 vecsz;
    is.read(reinterpret_cast<char*>(&vecsz), sizeof(vecsz));
    if (is.fail() || vecsz < 0) goto bad;
    v->resize(vecsz);
    if (vecsz > 0) {
      is.read(reinterpret_cast<char*>(&((*v)[0])), sizeof(T) * vecsz);
    }
  } else {
    std::vector<T> tmp_v;  // use temporary so v doesn't use extra memory
                           // due to resizing.
    is >> std::ws;
    if (is.peek() != static_cast<int>('[')) {
      KALDI_ERR << "ReadIntegerVector: expected to see [, saw " << is.peek()
                << ", at file position " << is.tellg();
    }
    is.get();  // consume the '['.
    is >> std::ws;  // consume whitespace.
    while (is.peek() != static_cast<int>(']')) {
      if (sizeof(T) == 1) {  // read/write chars as numbers.
        int16 next_t;
        is >> next_t >> std::ws;
        if (is.fail())
          goto bad;
        else
          tmp_v.push_back((T)next_t);
      } else {
        T next_t;
        is >> next_t >> std::ws;
        if (is.fail())
          goto bad;
        else
          tmp_v.push_back(next_t);
      }
    }
    is.get();  // get the final ']'.
    *v = tmp_v;
    // could use std::swap to use less temporary memory, but this
    // uses less permanent memory.
  }
  if (!is.fail()) return;
bad:
  KALDI_ERR << "ReadIntegerVector: read failure at file position "
            << is.tellg();
}
// Initialize an opened stream for writing: in binary mode emit the two-byte
// "\0B" header, and in either mode raise the floating-point precision to at
// least 7 digits (a bit more than float precision).
// This does not throw exceptions (does not check for errors).
inline void InitKaldiOutputStream(std::ostream& os, bool binary) {
  if (binary) {
    static const char kBinaryHeader[2] = {'\0', 'B'};
    os.write(kBinaryHeader, 2);
  }
  const std::streamsize kMinPrecision = 7;
  if (os.precision() < kMinPrecision) {
    os.precision(kMinPrecision);
  }
}
/// Initialize an opened stream for reading by detecting the binary header and
// setting the "binary" value appropriately: a leading "\0B" is consumed and
// means binary; any other first byte means text. Returns false only in the
// unusual case where the stream starts with '\0' but is not followed by 'B'.
inline bool InitKaldiInputStream(std::istream& is, bool* binary) {
  if (is.peek() != '\0') {
    // No binary marker: leave the stream untouched, treat as text.
    *binary = false;
    return true;
  }
  is.get();  // consume the '\0'.
  if (is.peek() != 'B') {
    return false;  // '\0' not followed by 'B' — malformed header.
  }
  is.get();  // consume the 'B'.
  *binary = true;
  return true;
}
}
// end namespace kaldi.
#endif // KALDI_BASE_IO_FUNCS_INL_H_
runtime/core/kaldi/base/io-funcs.cc
0 → 100644
View file @
764b3a75
// base/io-funcs.cc
// Copyright 2009-2011 Microsoft Corporation; Saarland University
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/io-funcs.h"
#include "base/kaldi-math.h"
namespace
kaldi
{
// bool specialization: written as "T"/"F" in both modes; text mode appends a
// trailing space as the token separator.
template <>
void WriteBasicType<bool>(std::ostream& os, bool binary, bool b) {
  os << (b ? "T" : "F");
  if (!binary) os << " ";
  if (os.fail()) KALDI_ERR << "Write failure in WriteBasicType<bool>";
}
// bool specialization: expects a single 'T' or 'F' character (text mode skips
// leading whitespace first); anything else raises KALDI_ERR.
template <>
void ReadBasicType<bool>(std::istream& is, bool binary, bool* b) {
  KALDI_PARANOID_ASSERT(b != NULL);
  if (!binary) is >> std::ws;  // eat up whitespace.
  char c = is.peek();
  if (c == 'T') {
    *b = true;
    is.get();
  } else if (c == 'F') {
    *b = false;
    is.get();
  } else {
    KALDI_ERR << "Read failure in ReadBasicType<bool>, file position is "
              << is.tellg() << ", next char is " << CharToString(c);
  }
}
// float specialization: binary mode writes a one-byte sizeof tag then the raw
// bytes; text mode writes the decimal value plus a space. Unlike the integer
// template, no failure check is performed here.
template <>
void WriteBasicType<float>(std::ostream& os, bool binary, float f) {
  if (binary) {
    char c = sizeof(f);
    os.put(c);
    os.write(reinterpret_cast<const char*>(&f), sizeof(f));
  } else {
    os << f << " ";
  }
}
// double specialization: same layout as the float specialization, with an
// 8-byte sizeof tag in binary mode.
template <>
void WriteBasicType<double>(std::ostream& os, bool binary, double f) {
  if (binary) {
    char c = sizeof(f);
    os.put(c);
    os.write(reinterpret_cast<const char*>(&f), sizeof(f));
  } else {
    os << f << " ";
  }
}
// float specialization: in binary mode the sizeof tag allows reading either a
// float (4) directly or a double (8) that is then narrowed — so data written
// as double can be read back as float.
template <>
void ReadBasicType<float>(std::istream& is, bool binary, float* f) {
  KALDI_PARANOID_ASSERT(f != NULL);
  if (binary) {
    double d;
    int c = is.peek();
    if (c == sizeof(*f)) {
      is.get();
      is.read(reinterpret_cast<char*>(f), sizeof(*f));
    } else if (c == sizeof(d)) {
      // Stored as double: read it via the double path and narrow.
      ReadBasicType(is, binary, &d);
      *f = d;
    } else {
      KALDI_ERR << "ReadBasicType: expected float, saw " << is.peek()
                << ", at file position " << is.tellg();
    }
  } else {
    is >> *f;
  }
  if (is.fail()) {
    KALDI_ERR << "ReadBasicType: failed to read, at file position "
              << is.tellg();
  }
}
// Specialization for double.  Mirror of the float specialization: the
// leading size byte in binary mode lets data written as float be read
// into a double.  Throws on bad size byte or stream failure.
template <>
void ReadBasicType<double>(std::istream &is, bool binary, double *d) {
  KALDI_PARANOID_ASSERT(d != NULL);
  if (binary) {
    float f;
    int c = is.peek();
    if (c == sizeof(*d)) {
      is.get();  // consume the size byte.
      is.read(reinterpret_cast<char *>(d), sizeof(*d));
    } else if (c == sizeof(f)) {
      // Written as float: read it as such and widen.
      ReadBasicType(is, binary, &f);
      *d = f;
    } else {
      // Fixed: this is the double specialization, so the message now says
      // "double" (it previously said "float", copied from the float version).
      KALDI_ERR << "ReadBasicType: expected double, saw " << is.peek()
                << ", at file position " << is.tellg();
    }
  } else {
    is >> *d;
  }
  if (is.fail()) {
    KALDI_ERR << "ReadBasicType: failed to read, at file position "
              << is.tellg();
  }
}
// Validates a token for WriteToken/ExpectToken: it must be nonempty and
// contain no whitespace (space is the token terminator in the stream).
// Throws via KALDI_ERR if the token is invalid.
void CheckToken(const char *token) {
  if (*token == '\0')
    KALDI_ERR << "Token is empty (not a valid token)";
  const char *orig_token = token;
  for (const char *p = orig_token; *p != '\0'; ++p) {
    if (::isspace(*p))
      KALDI_ERR << "Token is not a valid token (contains space): '"
                << orig_token << "'";
  }
}
// Writes a token followed by a single space.
// binary mode is ignored; we use space as the termination character in
// either case.  Throws on stream failure or invalid token.
void WriteToken(std::ostream &os, bool binary, const char *token) {
  KALDI_ASSERT(token != NULL);
  CheckToken(token);  // make sure it's valid (can be read back)
  os << token << " ";
  if (os.fail())
    KALDI_ERR << "Write failure in WriteToken.";
}
// Returns the next character of the stream without consuming it; in text
// mode, leading whitespace is skipped first.  Returns EOF (-1) at end of
// stream; does not throw.
int Peek(std::istream &is, bool binary) {
  if (!binary) is >> std::ws;  // eat up whitespace.
  return is.peek();
}
// std::string-flavored convenience wrapper: forwards to the
// const char* overload of WriteToken() above.
void WriteToken(std::ostream &os, bool binary, const std::string &token) {
  WriteToken(os, binary, token.c_str());
}
// Reads one whitespace-terminated token into *str and consumes the single
// trailing space.  Throws via KALDI_ERR if no token can be read or if the
// token is not followed by whitespace.
void ReadToken(std::istream &is, bool binary, std::string *str) {
  KALDI_ASSERT(str != NULL);
  if (!binary) is >> std::ws;  // consume whitespace.
  is >> *str;
  if (is.fail()) {
    KALDI_ERR << "ReadToken, failed to read token at file position "
              << is.tellg();
  }
  // A valid token must be terminated by whitespace.
  if (!isspace(is.peek())) {
    KALDI_ERR << "ReadToken, expected space after token, saw instead "
              << CharToString(static_cast<char>(is.peek()))
              << ", at file position " << is.tellg();
  }
  is.get();  // consume the space.
}
// Returns the first character of the next token (or -1 at EOF), skipping a
// leading '<' if present.  Attempts to unget the '<' so the stream position
// is unchanged; the C++ standard does not guarantee unget() succeeds, in
// which case the stream is left just past the '<' (callers such as
// ExpectToken() tolerate this).
int PeekToken(std::istream &is, bool binary) {
  if (!binary) is >> std::ws;  // consume whitespace.
  bool saw_bracket = false;
  if (static_cast<char>(is.peek()) == '<') {
    saw_bracket = true;
    is.get();  // step over the '<' to see the character after it.
  }
  int ans = is.peek();
  if (saw_bracket && !is.unget()) {
    // Clear the bad bit.  This code can be (and is in fact) reached, since
    // the C++ standard does not guarantee that a call to unget() must
    // succeed.
    is.clear();
  }
  return ans;
}
// Reads the next token and throws (via KALDI_ERR) unless it matches the
// expected token.  If we expect "<Foo>" we also accept "Foo>", to tolerate
// a failed unget() in a prior PeekToken() call; see the is.clear() call in
// PeekToken() for an explanation.
void ExpectToken(std::istream &is, bool binary, const char *token) {
  int pos_at_start = is.tellg();
  KALDI_ASSERT(token != NULL);
  CheckToken(token);  // make sure it's valid (can be read back)
  if (!binary) is >> std::ws;  // consume whitespace.
  std::string str;
  is >> str;
  is.get();  // consume the space.
  if (is.fail()) {
    KALDI_ERR << "Failed to read token [started at file position "
              << pos_at_start << "], expected " << token;
  }
  // The second half of the '&&' expression below is so that if we're
  // expecting "<Foo>", we will accept "Foo>" instead.
  bool exact_match = (strcmp(str.c_str(), token) == 0);
  bool bracketless_match =
      (token[0] == '<' && strcmp(str.c_str(), token + 1) == 0);
  if (!exact_match && !bracketless_match) {
    KALDI_ERR << "Expected token \"" << token << "\", got instead \"" << str
              << "\".";
  }
}
// std::string-flavored convenience wrapper: forwards to the
// const char* overload of ExpectToken() above; throws if the next
// token does not match.
void ExpectToken(std::istream &is, bool binary, const std::string &token) {
  ExpectToken(is, binary, token.c_str());
}
}
// end namespace kaldi
runtime/core/kaldi/base/io-funcs.h
0 → 100644
View file @
764b3a75
// base/io-funcs.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Jan Silovsky; Yanmin Qian
// 2016 Xiaohui Zhang
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_IO_FUNCS_H_
#define KALDI_BASE_IO_FUNCS_H_
// This header only contains some relatively low-level I/O functions.
// The full Kaldi I/O declarations are in ../util/kaldi-io.h
// and ../util/kaldi-table.h
// They were put in util/ in order to avoid making the Matrix library
// dependent on them.
#include <cctype>
#include <string>
#include <utility>
#include <vector>
#include "base/io-funcs-inl.h"
#include "base/kaldi-common.h"
namespace
kaldi
{
/*
This comment describes the Kaldi approach to I/O. All objects can be written
and read in two modes: binary and text. In addition we want to make the I/O
work if we redefine the typedef "BaseFloat" between floats and doubles.
We also want to have control over whitespace in text mode without affecting
the meaning of the file, for pretty-printing purposes.
Errors are handled by throwing a KaldiFatalError exception.
For integer and floating-point types (and boolean values):
WriteBasicType(std::ostream &, bool binary, const T&);
ReadBasicType(std::istream &, bool binary, T*);
and we expect these functions to be defined in such a way that they work when
the type T changes between float and double, so you can read float into double
and vice versa. Note that for efficiency and space-saving reasons, the
Vector and Matrix classes do not use these functions [but they preserve the
type interchangeability in their own way]
For a class (or struct) C:
class C {
..
Write(std::ostream &, bool binary, [possibly extra optional args for
specific classes]) const; Read(std::istream &, bool binary, [possibly extra
optional args for specific classes]);
..
}
NOTE: The only actual optional args we used are the "add" arguments in
Vector/Matrix classes, which specify whether we should sum the data already
in the class with the data being read.
For types which are typedef's involving stl classes, I/O is as follows:
typedef std::vector<std::pair<A, B> > MyTypedefName;
The user should define something like:
WriteMyTypedefName(std::ostream &, bool binary, const MyTypedefName &t);
ReadMyTypedefName(std::ostream &, bool binary, MyTypedefName *t);
The user would have to write these functions.
For a type std::vector<T>:
void WriteIntegerVector(std::ostream &os, bool binary, const std::vector<T>
&v); void ReadIntegerVector(std::istream &is, bool binary, std::vector<T> *v);
For other types, e.g. vectors of pairs, the user should create a routine of
the type WriteMyTypedefName. This is to avoid introducing confusing templated
functions; we could easily create templated functions to handle most of these
cases but they would have to share the same name.
It also often happens that the user needs to write/read special tokens as part
of a file. These might be class headers, or separators/identifiers in the
class. We provide special functions for manipulating these. These special
tokens must be nonempty and must not contain any whitespace.
void WriteToken(std::ostream &os, bool binary, const char*);
void WriteToken(std::ostream &os, bool binary, const std::string & token);
int Peek(std::istream &is, bool binary);
void ReadToken(std::istream &is, bool binary, std::string *str);
int PeekToken(std::istream &is, bool binary);
WriteToken writes the token and one space (whether in binary or text mode).
Peek returns the first character of the next token, by consuming whitespace
(in text mode) and then returning the peek() character. It returns -1 at EOF;
it doesn't throw. It's useful if a class can have various forms based on
typedefs and virtual classes, and wants to know which version to read.
ReadToken allows the caller to obtain the next token. PeekToken works just
like ReadToken, but seeks back to the beginning of the token. A subsequent
call to ReadToken will read the same token again. This is useful when
different object types are written to the same file; using PeekToken one can
decide which of the objects to read.
There is currently no special functionality for writing/reading strings (where
the strings contain data rather than "special tokens" that are whitespace-free
and nonempty). This is because Kaldi is structured in such a way that strings
don't appear, except as OpenFst symbol table entries (and these have their own
format).
NOTE: you should not call ReadIntegerType and WriteIntegerType with types,
such as int and size_t, whose sizes are machine-dependent -- at least not
if you want your file formats to port between machines. Use int32 and
int64 where necessary. There is no way to detect this using compile-time
assertions because C++ only keeps track of the internal representation of
the type.
*/
/// \addtogroup io_funcs_basic
/// @{
/// WriteBasicType is the name of the write function for bool, integer types,
/// and floating-point types. They all throw on error.
template
<
class
T
>
void
WriteBasicType
(
std
::
ostream
&
os
,
bool
binary
,
T
t
);
/// ReadBasicType is the name of the read function for bool, integer types,
/// and floating-point types. They all throw on error.
template
<
class
T
>
void
ReadBasicType
(
std
::
istream
&
is
,
bool
binary
,
T
*
t
);
// Declare specialization for bool.
template
<
>
void
WriteBasicType
<
bool
>
(
std
::
ostream
&
os
,
bool
binary
,
bool
b
);
template
<
>
void
ReadBasicType
<
bool
>
(
std
::
istream
&
is
,
bool
binary
,
bool
*
b
);
// Declare specializations for float and double.
template
<
>
void
WriteBasicType
<
float
>
(
std
::
ostream
&
os
,
bool
binary
,
float
f
);
template
<
>
void
WriteBasicType
<
double
>
(
std
::
ostream
&
os
,
bool
binary
,
double
f
);
template
<
>
void
ReadBasicType
<
float
>
(
std
::
istream
&
is
,
bool
binary
,
float
*
f
);
template
<
>
void
ReadBasicType
<
double
>
(
std
::
istream
&
is
,
bool
binary
,
double
*
f
);
// Define ReadBasicType that accepts an "add" parameter to add to
// the destination. Caution: if used in Read functions, be careful
// to initialize the parameters concerned to zero in the default
// constructor.
// Overload of ReadBasicType that accepts an "add" parameter: when add is
// true, the value read is added to *t instead of overwriting it.  Caution:
// if used in Read functions, be careful to initialize the parameters
// concerned to zero in the default constructor.
template <class T>
inline void ReadBasicType(std::istream &is, bool binary, T *t, bool add) {
  if (!add) {
    ReadBasicType(is, binary, t);
    return;
  }
  T tmp = T(0);
  ReadBasicType(is, binary, &tmp);
  *t += tmp;
}
/// Function for writing STL vectors of integer types.
template
<
class
T
>
inline
void
WriteIntegerVector
(
std
::
ostream
&
os
,
bool
binary
,
const
std
::
vector
<
T
>
&
v
);
/// Function for reading STL vector of integer types.
template
<
class
T
>
inline
void
ReadIntegerVector
(
std
::
istream
&
is
,
bool
binary
,
std
::
vector
<
T
>
*
v
);
/// Function for writing STL vectors of pairs of integer types.
template
<
class
T
>
inline
void
WriteIntegerPairVector
(
std
::
ostream
&
os
,
bool
binary
,
const
std
::
vector
<
std
::
pair
<
T
,
T
>
>
&
v
);
/// Function for reading STL vector of pairs of integer types.
template
<
class
T
>
inline
void
ReadIntegerPairVector
(
std
::
istream
&
is
,
bool
binary
,
std
::
vector
<
std
::
pair
<
T
,
T
>
>
*
v
);
/// The WriteToken functions are for writing nonempty sequences of non-space
/// characters. They are not for general strings.
void
WriteToken
(
std
::
ostream
&
os
,
bool
binary
,
const
char
*
token
);
void
WriteToken
(
std
::
ostream
&
os
,
bool
binary
,
const
std
::
string
&
token
);
/// Peek consumes whitespace (if binary == false) and then returns the peek()
/// value of the stream.
int
Peek
(
std
::
istream
&
is
,
bool
binary
);
/// ReadToken gets the next token and puts it in str (exception on failure). If
/// PeekToken() had been previously called, it is possible that the stream had
/// failed to unget the starting '<' character. In this case ReadToken() returns
/// the token string without the leading '<'. You must be prepared to handle
/// this case. ExpectToken() handles this internally, and is not affected.
void
ReadToken
(
std
::
istream
&
is
,
bool
binary
,
std
::
string
*
token
);
/// PeekToken will return the first character of the next token, or -1 if end of
/// file. It's the same as Peek(), except if the first character is '<' it will
/// skip over it and will return the next character. It will attempt to unget
/// the '<' so the stream is where it was before you did PeekToken(), however,
/// this is not guaranteed (see ReadToken()).
int
PeekToken
(
std
::
istream
&
is
,
bool
binary
);
/// ExpectToken tries to read in the given token, and throws an exception
/// on failure.
void
ExpectToken
(
std
::
istream
&
is
,
bool
binary
,
const
char
*
token
);
void
ExpectToken
(
std
::
istream
&
is
,
bool
binary
,
const
std
::
string
&
token
);
/// ExpectPretty attempts to read the text in "token", but only in non-binary
/// mode. Throws exception on failure. It expects an exact match except that
/// arbitrary whitespace matches arbitrary whitespace.
void
ExpectPretty
(
std
::
istream
&
is
,
bool
binary
,
const
char
*
token
);
void
ExpectPretty
(
std
::
istream
&
is
,
bool
binary
,
const
std
::
string
&
token
);
/// @} end "addtogroup io_funcs_basic"
/// InitKaldiOutputStream initializes an opened stream for writing by writing an
/// optional binary header and modifying the floating-point precision; it will
/// typically not be called by users directly.
inline
void
InitKaldiOutputStream
(
std
::
ostream
&
os
,
bool
binary
);
/// InitKaldiInputStream initializes an opened stream for reading by detecting
/// the binary header and setting the "binary" value appropriately;
/// It will typically not be called by users directly.
inline
bool
InitKaldiInputStream
(
std
::
istream
&
is
,
bool
*
binary
);
}
// end namespace kaldi.
#endif // KALDI_BASE_IO_FUNCS_H_
runtime/core/kaldi/base/kaldi-common.h
0 → 100644
View file @
764b3a75
// base/kaldi-common.h
// Copyright 2009-2011 Microsoft Corporation
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_COMMON_H_
#define KALDI_BASE_KALDI_COMMON_H_ 1
#include <cstddef>
#include <cstdlib>
#include <cstring> // C string stuff like strcpy
#include <string>
#include <sstream>
#include <stdexcept>
#include <cassert>
#include <vector>
#include <iostream>
#include <fstream>
#include "base/kaldi-utils.h"
#include "base/kaldi-error.h"
#include "base/kaldi-types.h"
// #include "base/io-funcs.h"
#include "base/kaldi-math.h"
// #include "base/timer.h"
#endif // KALDI_BASE_KALDI_COMMON_H_
runtime/core/kaldi/base/kaldi-error.cc
0 → 100644
View file @
764b3a75
// base/kaldi-error.cc
// Copyright 2019 LAIX (Yi Sun)
// Copyright 2019 SmartAction LLC (kkm)
// Copyright 2016 Brno University of Technology (author: Karel Vesely)
// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; Ondrej Glembek
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-error.h"
#include <string>
namespace
kaldi
{
/***** GLOBAL VARIABLES FOR LOGGING *****/

// Verbosity level; 0 by default.  Declared extern in kaldi-error.h and
// read/written via GetVerboseLevel()/SetVerboseLevel().
int32 g_kaldi_verbose_level = 0;

// Base name of the executing program, set via SetProgramName().
static std::string program_name;  // NOLINT
// Records the program's base name (no directory) for use by logging code.
// Using the 'static std::string' for the program name is mostly harmless,
// because (a) Kaldi logging is undefined before main(), and (b) no stdc++
// string implementation has been found in the wild that would not be just
// an empty string when zero-initialized but not yet constructed.
void SetProgramName(const char *basename) {
  program_name.assign(basename);
}
}
// namespace kaldi
runtime/core/kaldi/base/kaldi-error.h
0 → 100644
View file @
764b3a75
// base/kaldi-error.h
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_ERROR_H_
#define KALDI_BASE_KALDI_ERROR_H_ 1
#include "utils/log.h"
namespace
kaldi
{
#define KALDI_WARN \
google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING).stream()
#define KALDI_ERR \
google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR).stream()
#define KALDI_LOG \
google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO).stream()
#define KALDI_VLOG(v) VLOG(v)
#define KALDI_ASSERT(condition) CHECK(condition)
/***** PROGRAM NAME AND VERBOSITY LEVEL *****/
/// Called by ParseOptions to set base name (no directory) of the executing
/// program. The name is printed in logging code along with every message,
/// because in our scripts, we often mix together the stderr of many programs.
/// This function is very thread-unsafe.
void
SetProgramName
(
const
char
*
basename
);
/// This is set by util/parse-options.{h,cc} if you set --verbose=? option.
/// Do not use directly, prefer {Get,Set}VerboseLevel().
extern
int32
g_kaldi_verbose_level
;
/// Get verbosity level, usually set via command line '--verbose=' switch.
inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; }
/// This should be rarely used, except by programs using Kaldi as library;
/// command-line programs set the verbose level automatically from
/// ParseOptions.  Not thread-safe: sets a plain global.
inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; }
}
// namespace kaldi
#endif // KALDI_BASE_KALDI_ERROR_H_
runtime/core/kaldi/base/kaldi-math.cc
0 → 100644
View file @
764b3a75
// base/kaldi-math.cc
// Copyright 2009-2011 Microsoft Corporation; Yanmin Qian;
// Saarland University; Jan Silovsky
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-math.h"
#ifndef _MSC_VER
#include <stdlib.h>
#include <unistd.h>
#endif
#include <string>
#include <mutex>
namespace
kaldi
{
// These routines are tested in matrix/matrix-test.cc

// Returns the smallest power of two >= n.  Crashes (assert) if n <= 0.
int32 RoundUpToNearestPowerOfTwo(int32 n) {
  KALDI_ASSERT(n > 0);
  // Smear the highest set bit of (n - 1) into every lower bit position,
  // then add one to obtain the power of two.
  n--;
  for (int shift = 1; shift < 32; shift *= 2)
    n |= n >> shift;
  return n + 1;
}
// Protects the global rand() state when no per-caller RandomState is given.
static std::mutex _RandMutex;

// Returns a random integer in [0, RAND_MAX].  If 'state' is non-NULL, uses
// the reentrant rand_r() with the caller's seed (thread-safe without
// locking); otherwise serializes access to the global rand().
int Rand(struct RandomState *state) {
#if !defined(_POSIX_THREAD_SAFE_FUNCTIONS)
  // On Windows and Cygwin, just call Rand()
  return rand();
#else
  if (state) {
    return rand_r(&(state->seed));
  } else {
    std::lock_guard<std::mutex> lock(_RandMutex);
    return rand();
  }
#endif
}
// Seeds the per-thread random state from the global generator.
RandomState::RandomState() {
  // we initialize it as Rand() + 27437 instead of just Rand(), because on
  // some systems, e.g. at the very least Mac OSX Yosemite and later, it
  // seems to be the case that rand_r when initialized with rand() will give
  // you the exact same sequence of numbers that rand() will give if you keep
  // calling rand() after that initial call.  This can cause problems with
  // repeated sequences.  For example if you initialize two RandomState
  // structs one after the other without calling rand() in between, they
  // would give you the same sequence offset by one (if we didn't have the
  // "+ 27437" in the code).  27437 is just a randomly chosen prime number.
  seed = unsigned(Rand()) + 27437;
}
// Returns true with probability "prob" (0 <= prob <= 1, checked).
// Carefully implemented so it works even when prob is very small relative
// to 1/RAND_MAX.
bool WithProb(BaseFloat prob, struct RandomState *state) {
  KALDI_ASSERT(prob >= 0 && prob <= 1.1);  // prob should be <= 1.0,
  // but we allow slightly larger values that could arise from roundoff in
  // previous calculations.
  KALDI_COMPILE_TIME_ASSERT(RAND_MAX > 128 * 128);
  if (prob == 0) {
    return false;
  } else if (prob == 1.0) {
    return true;
  } else if (prob * RAND_MAX < 128.0) {
    // prob is very small but nonzero, and the "main algorithm"
    // wouldn't work that well.  So: with probability 1/128, we
    // return WithProb(prob * 128), else return false.
    if (Rand(state) < RAND_MAX / 128) {  // with probability 1/128...
      // Note: we know that prob * 128.0 < 1.0, because
      // we asserted RAND_MAX > 128 * 128.
      // Fixed: propagate 'state' into the recursive call.  Previously it
      // was dropped, so the recursion fell back to the global (locked)
      // rand() even when the caller supplied a thread-local RandomState.
      return WithProb(prob * 128.0, state);
    } else {
      return false;
    }
  } else {
    // Main algorithm: compare Rand() against prob scaled to [0, RAND_MAX+1).
    return (Rand(state) < ((RAND_MAX + static_cast<BaseFloat>(1.0)) * prob));
  }
}
// Returns a roughly-uniform random integer in [min_val, max_val].
// "This is not exact" because the modulus introduces a small bias.
int32 RandInt(int32 min_val, int32 max_val, struct RandomState *state) {
  // This is not exact.
  KALDI_ASSERT(max_val >= min_val);
  if (max_val == min_val) return min_val;
#ifdef _MSC_VER
  // RAND_MAX is quite small on Windows -> may need to handle larger numbers.
  if (RAND_MAX > (max_val - min_val) * 8) {
    // *8 to avoid large inaccuracies in probability, from the modulus...
    return min_val +
           ((unsigned int)Rand(state) % (unsigned int)(max_val + 1 - min_val));
  } else {
    // NOTE(review): RAND_MAX * RAND_MAX is evaluated in (signed) int and
    // looks like it can overflow before the cast to unsigned -- this
    // MSVC-only branch is preserved as-is; confirm against upstream Kaldi.
    if ((unsigned int)(RAND_MAX * RAND_MAX) >
        (unsigned int)((max_val + 1 - min_val) * 8)) {
      // *8 to avoid inaccuracies in probability, from the modulus...
      return min_val +
             ((unsigned int)((Rand(state) + RAND_MAX * Rand(state))) %
              (unsigned int)(max_val + 1 - min_val));
    } else {
      KALDI_ERR << "rand_int failed because we do not support such large "
                   "random numbers. (Extend this function).";
    }
  }
#else
  return min_val + (static_cast<int32>(Rand(state)) %
                    static_cast<int32>(max_val + 1 - min_val));
#endif
}
// Returns poisson-distributed random number.
// Take care: this takes time proportional
// to lambda.  Faster algorithms exist but are more complex.
int32 RandPoisson(float lambda, struct RandomState *state) {
  // Knuth's algorithm: multiply uniform draws until the running product
  // falls below exp(-lambda); the count of draws needed is Poisson+1.
  KALDI_ASSERT(lambda >= 0);
  const float L = expf(-lambda);
  float p = 1.0;
  int32 k = 0;
  do {
    k++;
    p *= RandUniform(state);
  } while (p > L);
  return k - 1;
}
// Writes two independent standard-normal deviates to *a and *b,
// generated with the Box-Muller transform from two uniform draws.
void RandGauss2(float *a, float *b, RandomState *state) {
  KALDI_ASSERT(a);
  KALDI_ASSERT(b);
  const float radius = sqrtf(-2.0f * logf(RandUniform(state)));
  const float theta = 2.0f * M_PI * RandUniform(state);
  *a = radius * cosf(theta);
  *b = radius * sinf(theta);
}
// Double-precision variant.
// Just because we're using doubles doesn't mean we need super-high-quality
// random numbers, so we just use the floating-point version internally.
void RandGauss2(double *a, double *b, RandomState *state) {
  KALDI_ASSERT(a);
  KALDI_ASSERT(b);
  float a_float;
  float b_float;
  RandGauss2(&a_float, &b_float, state);
  *a = a_float;
  *b = b_float;
}
}
// end namespace kaldi
runtime/core/kaldi/base/kaldi-math.h
0 → 100644
View file @
764b3a75
// base/kaldi-math.h
// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian;
// Jan Silovsky; Saarland University
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_MATH_H_
#define KALDI_BASE_KALDI_MATH_H_ 1
#ifdef _MSC_VER
#include <float.h>
#endif
#include <cmath>
#include <limits>
#include <vector>
#include "base/kaldi-types.h"
#include "base/kaldi-common.h"
#ifndef DBL_EPSILON
#define DBL_EPSILON 2.2204460492503131e-16
#endif
#ifndef FLT_EPSILON
#define FLT_EPSILON 1.19209290e-7f
#endif
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
#ifndef M_SQRT2
#define M_SQRT2 1.4142135623730950488016887
#endif
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
#ifndef M_SQRT1_2
#define M_SQRT1_2 0.7071067811865475244008443621048490
#endif
#ifndef M_LOG_2PI
#define M_LOG_2PI 1.8378770664093454835606594728112
#endif
#ifndef M_LN2
#define M_LN2 0.693147180559945309417232121458
#endif
#ifndef M_LN10
#define M_LN10 2.302585092994045684017991454684
#endif
#define KALDI_ISNAN std::isnan
#define KALDI_ISINF std::isinf
#define KALDI_ISFINITE(x) std::isfinite(x)
#if !defined(KALDI_SQR)
# define KALDI_SQR(x) ((x) * (x))
#endif
namespace
kaldi
{
#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
// Overloaded exponential: picks exp()/expf() per precision.
inline double Exp(double x) { return exp(x); }
#ifndef KALDI_NO_EXPF
inline float Exp(float x) { return expf(x); }
#else
// Some platforms define KALDI_NO_EXPF to avoid a broken expf();
// compute in double and narrow.
inline float Exp(float x) { return exp(static_cast<double>(x)); }
#endif  // KALDI_NO_EXPF
#else
inline double Exp(double x) { return exp(x); }
#if !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
// Microsoft CL v18.0 buggy 64-bit implementation of
// expf() incorrectly returns -inf for exp(-inf).
inline float Exp(float x) { return exp(static_cast<double>(x)); }
#else
inline float Exp(float x) { return expf(x); }
#endif  // !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
#endif  // !defined(_MSC_VER) || (_MSC_VER >= 1900)
// Overloaded natural logarithm: dispatches to log()/logf() by precision.
inline double Log(double x) {
  return log(x);
}
inline float Log(float x) {
  return logf(x);
}
#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
// log(1+x), accurate for small x.
inline double Log1p(double x) { return log1p(x); }
inline float Log1p(float x) { return log1pf(x); }
#else
// Old MSVC lacks log1p(): use a 2-term Taylor expansion below a cutoff
// (where log(1.0 + x) would lose precision) and plain Log(1.0 + x) above.
inline double Log1p(double x) {
  const double cutoff = 1.0e-08;
  if (x < cutoff)
    return x - 0.5 * x * x;
  else
    return Log(1.0 + x);
}
inline float Log1p(float x) {
  const float cutoff = 1.0e-07;
  if (x < cutoff)
    return x - 0.5 * x * x;
  else
    return Log(1.0 + x);
}
#endif
// Log of machine epsilon: differences below these (negative) thresholds
// make no difference to LogAdd at the respective precision.
static const double kMinLogDiffDouble = Log(DBL_EPSILON);  // negative!
static const float kMinLogDiffFloat = Log(FLT_EPSILON);    // negative!

// -infinity
const float kLogZeroFloat = -std::numeric_limits<float>::infinity();
const double kLogZeroDouble = -std::numeric_limits<double>::infinity();
const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity();
// Returns a random integer between 0 and RAND_MAX, inclusive
int Rand(struct RandomState *state = NULL);

// State for thread-safe random number generator
struct RandomState {
  RandomState();   // seeds from the global generator; see kaldi-math.cc.
  unsigned seed;   // passed to rand_r() by Rand().
};

// Returns a random integer between first and last inclusive.
int32 RandInt(int32 first, int32 last, struct RandomState *state = NULL);

// Returns true with probability "prob",
bool WithProb(BaseFloat prob, struct RandomState *state = NULL);
// with 0 <= prob <= 1 [we check this].
// Internally calls Rand().  This function is carefully implemented so
// that it should work even if prob is very small.
/// Returns a random number strictly between 0 and 1.
inline float RandUniform(struct RandomState *state = NULL) {
  // (Rand()+1) / (RAND_MAX+2) maps [0, RAND_MAX] into the open interval
  // (0, 1), so neither endpoint can be returned.
  return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX + 2.0));
}
/// Returns one standard-normal deviate via the Box-Muller transform
/// (uses two uniform draws; only the cosine branch is kept).
inline float RandGauss(struct RandomState *state = NULL) {
  const float radius = sqrtf(-2 * Log(RandUniform(state)));
  const float angle = 2 * M_PI * RandUniform(state);
  return static_cast<float>(radius * cosf(angle));
}
// Returns poisson-distributed random number. Uses Knuth's algorithm.
// Take care: this takes time proportional
// to lambda. Faster algorithms exist but are more complex.
int32
RandPoisson
(
float
lambda
,
struct
RandomState
*
state
=
NULL
);
// Returns a pair of gaussian random numbers. Uses Box-Muller transform
void
RandGauss2
(
float
*
a
,
float
*
b
,
RandomState
*
state
=
NULL
);
void
RandGauss2
(
double
*
a
,
double
*
b
,
RandomState
*
state
=
NULL
);
// Also see Vector<float,double>::RandCategorical().
// This is a randomized pruning mechanism that preserves expectations,
// that we typically use to prune posteriors.
template <class Float>
inline Float RandPrune(Float post, BaseFloat prune_thresh,
                       struct RandomState *state = NULL) {
  KALDI_ASSERT(prune_thresh >= 0.0);
  // Values at or above the threshold (and exact zeros) pass through.
  if (post == 0.0 || std::abs(post) >= prune_thresh) return post;
  // Below the threshold: keep (rounded up to +/- prune_thresh) with
  // probability |post|/prune_thresh, else zero -- expectation unchanged.
  Float sign = (post >= 0 ? 1.0 : -1.0);
  Float kept =
      (RandUniform(state) <= fabs(post) / prune_thresh) ? prune_thresh : 0.0;
  return sign * kept;
}
// returns log(exp(x) + exp(y)).
inline double LogAdd(double x, double y) {
  double diff;
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff >= kMinLogDiffDouble)
    return x + Log1p(Exp(diff));
  // The smaller term is below double precision; just return the larger one.
  return x;
}
// returns log(exp(x) + exp(y)).
inline float LogAdd(float x, float y) {
  float diff;
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.
  if (diff >= kMinLogDiffFloat)
    return x + Log1p(Exp(diff));
  // The smaller term is below float precision; just return the larger one.
  return x;
}
// returns log(exp(x) - exp(y)).  Requires x >= y; throws if y > x.
inline double LogSub(double x, double y) {
  if (y >= x) {  // Throws exception if y>=x.
    if (y == x)
      return kLogZeroDouble;
    else
      KALDI_ERR << "Cannot subtract a larger from a smaller number.";
  }
  double diff = y - x;  // Will be negative.
  double res = x + Log(1.0 - Exp(diff));
  // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine
  // precision
  if (KALDI_ISNAN(res)) return kLogZeroDouble;
  return res;
}
// returns log(exp(x) - exp(y)).
inline
float
LogSub
(
float
x
,
float
y
)
{
if
(
y
>=
x
)
{
// Throws exception if y>=x.
if
(
y
==
x
)
return
kLogZeroDouble
;
else
KALDI_ERR
<<
"Cannot subtract a larger from a smaller number."
;
}
float
diff
=
y
-
x
;
// Will be negative.
float
res
=
x
+
Log
(
1.0
f
-
Exp
(
diff
));
// res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
if
(
KALDI_ISNAN
(
res
))
return
kLogZeroFloat
;
return
res
;
}
/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
static inline bool ApproxEqual(float a, float b,
                               float relative_tolerance = 0.001) {
  if (a == b) return true;  // a==b handles infinities.
  float diff = std::abs(a - b);
  // If diff is +inf (opposite infinities, or overflow) or NaN,
  // the values cannot be approximately equal.
  bool diff_is_bad =
      (diff == std::numeric_limits<float>::infinity() || diff != diff);
  if (diff_is_bad) return false;
  return diff <= relative_tolerance * (std::abs(a) + std::abs(b));
}
/// Asserts that abs(a - b) <= relative_tolerance * (abs(a) + abs(b)).
/// Delegates the comparison to ApproxEqual(); exact equality (a == b)
/// therefore also handles infinities.
static inline void AssertEqual(float a, float b,
                               float relative_tolerance = 0.001) {
  // a==b handles infinities.
  KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance));
}
// RoundUpToNearestPowerOfTwo returns the smallest power of two >= n.
// It crashes if n <= 0.  (Declaration only; defined in kaldi-math.cc.)
int32 RoundUpToNearestPowerOfTwo(int32 n);
/// Returns a / b, rounding towards negative infinity in all cases
/// (C++ integer division truncates towards zero, so the negative-quotient
/// cases need adjustment).  Crashes if b == 0.
static inline int32 DivideRoundingDown(int32 a, int32 b) {
  KALDI_ASSERT(b != 0);
  // Fixed: the previous test `a * b >= 0` could overflow int32 for large
  // operands (signed overflow is undefined behavior).  Comparing signs is
  // equivalent for all non-overflowing inputs and never overflows.
  if (a == 0 || (a > 0) == (b > 0)) {
    // Quotient is non-negative: truncation towards zero already equals floor.
    return a / b;
  } else if (a < 0) {
    // a < 0, b > 0: bias the numerator so truncation rounds down.
    return (a - b + 1) / b;
  } else {
    // a > 0, b < 0: symmetric adjustment.
    return (a - b - 1) / b;
  }
}
template
<
class
I
>
I
Gcd
(
I
m
,
I
n
)
{
if
(
m
==
0
||
n
==
0
)
{
if
(
m
==
0
&&
n
==
0
)
{
// gcd not defined, as all integers are divisors.
KALDI_ERR
<<
"Undefined GCD since m = 0, n = 0."
;
}
return
(
m
==
0
?
(
n
>
0
?
n
:
-
n
)
:
(
m
>
0
?
m
:
-
m
));
// return absolute value of whichever is nonzero
}
// could use compile-time assertion
// but involves messing with complex template stuff.
KALDI_ASSERT
(
std
::
numeric_limits
<
I
>::
is_integer
);
while
(
1
)
{
m
%=
n
;
if
(
m
==
0
)
return
(
n
>
0
?
n
:
-
n
);
n
%=
m
;
if
(
n
==
0
)
return
(
m
>
0
?
m
:
-
m
);
}
}
/// Returns the least common multiple of two integers.  Will
/// crash unless the inputs are positive.
template <class I>
I Lcm(I m, I n) {
  KALDI_ASSERT(m > 0 && n > 0);
  const I g = Gcd(m, n);
  // lcm(m, n) = m * n / gcd(m, n); dividing first keeps the intermediate
  // values as small as possible.
  return (m / g) * n;
}
/// Splits m into its prime factors, stored in *factors in sorted order from
/// least to greatest, with duplication (e.g. 12 -> {2, 2, 3}).  Plain trial
/// division: very inefficient in general, but mainly intended for the
/// mixed-radix FFT computation, where we assume most factors are small.
template <class I>
void Factorize(I m, std::vector<I>* factors) {
  KALDI_ASSERT(factors != NULL);
  KALDI_ASSERT(m >= 1);  // Zero and negative numbers are not supported.
  factors->clear();
  const I first_primes[10] = {2, 3, 5, 7, 11, 13, 17, 19, 23, 29};
  // Divide out each of the ten smallest primes as many times as it divides m.
  for (int idx = 0; idx < 10 && m != 1; ++idx) {
    const I p = first_primes[idx];
    while (m % p == 0) {
      m /= p;
      factors->push_back(p);
    }
  }
  // Whatever remains has no prime factor below 31; trial-divide by
  // successive odd candidates until m is fully factorized.
  for (I candidate = 31; m != 1; candidate += 2) {
    while (m % candidate == 0) {
      m /= candidate;
      factors->push_back(candidate);
    }
  }
}
// Thin wrapper over the C library hypot(): computes sqrt(x*x + y*y) while
// avoiding intermediate overflow/underflow.
inline double Hypot(double x, double y) {
  return hypot(x, y);
}
// Thin wrapper over the C library hypotf(): single-precision counterpart of
// Hypot(double, double) above.
inline float Hypot(float x, float y) {
  return hypotf(x, y);
}
}
// namespace kaldi
#endif // KALDI_BASE_KALDI_MATH_H_
runtime/core/kaldi/base/kaldi-types.h
0 → 100644
View file @
764b3a75
// base/kaldi-types.h
// Copyright 2009-2011 Microsoft Corporation; Saarland University;
// Jan Silovsky; Yanmin Qian
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_TYPES_H_
#define KALDI_BASE_KALDI_TYPES_H_ 1
namespace
kaldi
{
// TYPEDEFS ..................................................................
#if (KALDI_DOUBLEPRECISION != 0)
typedef
double
BaseFloat
;
#else
typedef
float
BaseFloat
;
#endif
}
#ifdef _MSC_VER
#include <basetsd.h>
#define ssize_t SSIZE_T
#endif
// we can do this a different way if some platform
// we find in the future lacks stdint.h
#include <stdint.h>
// for discussion on what to do if you need compile kaldi
// without OpenFST, see the bottom of this this file
#include <fst/types.h>
namespace
kaldi
{
using
::
int16
;
using
::
int32
;
using
::
int64
;
using
::
uint16
;
using
::
uint32
;
using
::
uint64
;
typedef
float
float32
;
typedef
double
double64
;
}
// end namespace kaldi
// In a theoretical case you decide compile Kaldi without the OpenFST
// comment the previous namespace statement and uncomment the following
/*
namespace kaldi {
typedef int8_t int8;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;
typedef uint64_t uint64;
typedef float float32;
typedef double double64;
} // end namespace kaldi
*/
#endif // KALDI_BASE_KALDI_TYPES_H_
runtime/core/kaldi/base/kaldi-utils.h
0 → 100644
View file @
764b3a75
// base/kaldi-utils.h
// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation;
// Saarland University; Karel Vesely; Yanmin Qian
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_BASE_KALDI_UTILS_H_
#define KALDI_BASE_KALDI_UTILS_H_ 1
#if defined(_MSC_VER)
# define WIN32_LEAN_AND_MEAN
# define NOMINMAX
# include <windows.h>
#endif
#ifdef _MSC_VER
#include <stdio.h>
#define unlink _unlink
#else
#include <unistd.h>
#endif
#include <limits>
#include <string>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4056 4305 4800 4267 4996 4756 4661)
#if _MSC_VER < 1400
#define __restrict__
#else
#define __restrict__ __restrict
#endif
#endif
#if defined(_MSC_VER)
# define KALDI_MEMALIGN(align, size, pp_orig) \
(*(pp_orig) = _aligned_malloc(size, align))
# define KALDI_MEMALIGN_FREE(x) _aligned_free(x)
#elif defined(__CYGWIN__)
# define KALDI_MEMALIGN(align, size, pp_orig) \
(*(pp_orig) = aligned_alloc(align, size))
# define KALDI_MEMALIGN_FREE(x) free(x)
#else
# define KALDI_MEMALIGN(align, size, pp_orig) \
(!posix_memalign(pp_orig, align, size) ? *(pp_orig) : NULL)
# define KALDI_MEMALIGN_FREE(x) free(x)
#endif
#ifdef __ICC
#pragma warning(disable: 383) // ICPC remark we don't want.
#pragma warning(disable: 810) // ICPC remark we don't want.
#pragma warning(disable: 981) // ICPC remark we don't want.
#pragma warning(disable: 1418) // ICPC remark we don't want.
#pragma warning(disable: 444) // ICPC remark we don't want.
#pragma warning(disable: 869) // ICPC remark we don't want.
#pragma warning(disable: 1287) // ICPC remark we don't want.
#pragma warning(disable: 279) // ICPC remark we don't want.
#pragma warning(disable: 981) // ICPC remark we don't want.
#endif
namespace
kaldi
{
// CharToString prints the character in a human-readable form, for debugging.
std
::
string
CharToString
(
const
char
&
c
);
// Returns nonzero (1) on little-endian machines and 0 on big-endian ones,
// determined by inspecting which byte of an int holding 1 is the low byte.
inline int MachineIsLittleEndian() {
  int probe = 1;
  const char lowest_addressed_byte = *reinterpret_cast<char*>(&probe);
  return lowest_addressed_byte != 0 ? 1 : 0;
}
// This function kaldi::Sleep() provides a portable way
// to sleep for a possibly fractional
// number of seconds. On Windows it's only accurate to microseconds.
void
Sleep
(
float
seconds
);
}
// namespace kaldi
#define KALDI_SWAP8(a) do { \
int t = (reinterpret_cast<char*>(&a))[0];\
(reinterpret_cast<char*>(&a))[0]=(reinterpret_cast<char*>(&a))[7];\
(reinterpret_cast<char*>(&a))[7] = t;\
t = (reinterpret_cast<char*>(&a))[1];\
(reinterpret_cast<char*>(&a))[1]=(reinterpret_cast<char*>(&a))[6];\
(reinterpret_cast<char*>(&a))[6] = t;\
t = (reinterpret_cast<char*>(&a))[2];\
(reinterpret_cast<char*>(&a))[2]=(reinterpret_cast<char*>(&a))[5];\
(reinterpret_cast<char*>(&a))[5] = t;\
t = (reinterpret_cast<char*>(&a))[3];\
(reinterpret_cast<char*>(&a))[3]=(reinterpret_cast<char*>(&a))[4];\
(reinterpret_cast<char*>(&a))[4] = t;} while (0)
#define KALDI_SWAP4(a) do { \
int t = (reinterpret_cast<char*>(&a))[0];\
(reinterpret_cast<char*>(&a))[0]=(reinterpret_cast<char*>(&a))[3];\
(reinterpret_cast<char*>(&a))[3] = t;\
t = (reinterpret_cast<char*>(&a))[1];\
(reinterpret_cast<char*>(&a))[1]=(reinterpret_cast<char*>(&a))[2];\
(reinterpret_cast<char*>(&a))[2]=t;} while (0)
#define KALDI_SWAP2(a) do { \
int t = (reinterpret_cast<char*>(&a))[0];\
(reinterpret_cast<char*>(&a))[0]=(reinterpret_cast<char*>(&a))[1];\
(reinterpret_cast<char*>(&a))[1] = t;} while (0)
// Makes copy constructor and operator= private.
#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \
type(const type&); \
void operator = (const type&)
template
<
bool
B
>
class
KaldiCompileTimeAssert
{
};
template
<
>
class
KaldiCompileTimeAssert
<
true
>
{
public:
static
inline
void
Check
()
{
}
};
#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check()
#define KALDI_ASSERT_IS_INTEGER_TYPE(I) \
KaldiCompileTimeAssert<std::numeric_limits<I>::is_specialized \
&& std::numeric_limits<I>::is_integer>::Check()
#define KALDI_ASSERT_IS_FLOATING_TYPE(F) \
KaldiCompileTimeAssert<std::numeric_limits<F>::is_specialized \
&& !std::numeric_limits<F>::is_integer>::Check()
#if defined(_MSC_VER)
#define KALDI_STRCASECMP _stricmp
#elif defined(__CYGWIN__)
#include <strings.h>
#define KALDI_STRCASECMP strcasecmp
#else
#define KALDI_STRCASECMP strcasecmp
#endif
#ifdef _MSC_VER
# define KALDI_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10);
#else
# define KALDI_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10);
#endif
#endif // KALDI_BASE_KALDI_UTILS_H_
runtime/core/kaldi/decoder/lattice-faster-decoder.cc
0 → 100644
View file @
764b3a75
// decoder/lattice-faster-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2018 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// 2021 Binbin Zhang, Zhendong Peng
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <unordered_set>
#include "decoder/lattice-faster-decoder.h"
// #include "lat/lattice-functions.h"
namespace
kaldi
{
// instantiate this class once for each thing you have to decode.
// This constructor borrows the FST (delete_fst_ is false, so the destructor
// will not delete it) and stores a shared context graph used for
// start/end boundary tags when building the raw lattice.
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
    const FST& fst, const LatticeFasterDecoderConfig& config,
    const std::shared_ptr<wenet::ContextGraph>& context_graph)
    : fst_(&fst),
      delete_fst_(false),
      config_(config),
      num_toks_(0),
      context_graph_(context_graph) {
  // Validate the configuration up front (crashes on bad values).
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
// Ownership-taking constructor: the decoder takes responsibility for
// deleting *fst in its destructor (delete_fst_ is true).
// NOTE(review): context_graph_ is not in the initializer list here, unlike
// the other constructor — presumably it default-constructs to a null
// shared_ptr; confirm against the member declaration in the header.
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::LatticeFasterDecoderTpl(
    const LatticeFasterDecoderConfig& config, FST* fst)
    : fst_(fst), delete_fst_(true), config_(config), num_toks_(0) {
  // Validate the configuration up front (crashes on bad values).
  config.Check();
  toks_.SetSize(1000);  // just so on the first frame we do something reasonable.
}
// Destructor: releases the token hash entries, all per-frame token lists,
// and the FST if this decoder owns it (see the FST* constructor).
template <typename FST, typename Token>
LatticeFasterDecoderTpl<FST, Token>::~LatticeFasterDecoderTpl() {
  DeleteElems(toks_.Clear());
  ClearActiveTokens();
  if (delete_fst_) delete fst_;
}
// Resets all decoding state and seeds the search with a single token on the
// FST start state, then immediately expands non-emitting (epsilon) arcs.
// Must be called before AdvanceDecoding() when not using Decode().
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::InitDecoding() {
  // clean up from last time:
  DeleteElems(toks_.Clear());
  cost_offsets_.clear();
  ClearActiveTokens();
  warned_ = false;
  num_toks_ = 0;
  decoding_finalized_ = false;
  final_costs_.clear();
  StateId start_state = fst_->Start();
  KALDI_ASSERT(start_state != fst::kNoStateId);
  // Frame 0 holds the start token; frames are 1-based for real input.
  active_toks_.resize(1);
  // Token(tot_cost, extra_cost, links, next, backpointer): a free start
  // token with zero costs and no links or backpointer.
  Token* start_tok = new Token(0.0, 0.0, NULL, NULL, NULL);
  active_toks_[0].toks = start_tok;
  toks_.Insert(start_state, start_tok);
  num_toks_++;
  // Follow epsilon arcs from the start state within the beam.
  ProcessNonemitting(config_.beam);
}
// Returns true if any kind of traceback is available (not necessarily from
// a final state).  It should only very rarely return false; this indicates
// an unusual search error.
// Convenience entry point: runs the whole utterance in one call
// (InitDecoding + AdvanceDecoding + FinalizeDecoding).
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::Decode(DecodableInterface* decodable) {
  InitDecoding();
  // We use 1-based indexing for frames in this decoder (if you view it in
  // terms of features), but note that the decodable object uses zero-based
  // numbering, which we have to correct for when we call it.
  AdvanceDecoding(decodable);
  FinalizeDecoding();
  // Returns true if we have any kind of traceback available (not necessarily
  // to the end state; query ReachedFinal() for that).
  return !active_toks_.empty() && active_toks_.back().toks != NULL;
}
// Outputs an FST corresponding to the single best path through the lattice.
// Builds the raw state-level lattice first, then runs ShortestPath over it.
// Returns true iff the resulting FST is non-empty.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetBestPath(
    Lattice* olat, bool use_final_probs) const {
  Lattice raw_lat;
  GetRawLattice(&raw_lat, use_final_probs);
  ShortestPath(raw_lat, olat);
  return (olat->NumStates() != 0);
}
// Outputs an FST corresponding to the raw, state-level lattice.
// Creates one output state per surviving token (topologically sorted per
// frame), then one arc per surviving forward link; boundary links from the
// context graph get extra tagged epsilon arcs.  Returns true iff the output
// lattice is non-empty.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetRawLattice(
    Lattice* ofst, bool use_final_probs) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;
  typedef Arc::Label Label;

  // Note: you can't use the old interface (Decode()) if you want to
  // get the lattice with use_final_probs = false.  You'd have to do
  // InitDecoding() and then AdvanceDecoding().
  if (decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";

  unordered_map<Token*, BaseFloat> final_costs_local;

  // After finalization the final costs are cached in final_costs_;
  // otherwise compute them into the local map on demand.
  const unordered_map<Token*, BaseFloat>& final_costs =
      (decoding_finalized_ ? final_costs_ : final_costs_local);
  if (!decoding_finalized_ && use_final_probs)
    ComputeFinalCosts(&final_costs_local, NULL, NULL);

  ofst->DeleteStates();
  // num-frames plus one (since frames are one-based, and we have
  // an extra frame for the start-state).
  int32 num_frames = active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  const int32 bucket_count = num_toks_ / 2 + 3;
  unordered_map<Token*, StateId> tok_map(bucket_count);
  // First create all states.
  std::vector<Token*> token_list;
  for (int32 f = 0; f <= num_frames; f++) {
    if (active_toks_[f].toks == NULL) {
      KALDI_WARN << "GetRawLattice: no tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
    TopSortTokens(active_toks_[f].toks, &token_list);
    for (size_t i = 0; i < token_list.size(); i++)
      if (token_list[i] != NULL) tok_map[token_list[i]] = ofst->AddState();
  }
  // The next statement sets the start state of the output FST.  Because we
  // topologically sorted the tokens, state zero must be the start-state.
  ofst->SetStart(0);

  KALDI_VLOG(4) << "init:" << num_toks_ / 2 + 3
                << " buckets:" << tok_map.bucket_count()
                << " load:" << tok_map.load_factor()
                << " max:" << tok_map.max_load_factor();
  // Now create all arcs.
  for (int32 f = 0; f <= num_frames; f++) {
    for (Token* tok = active_toks_[f].toks; tok != NULL; tok = tok->next) {
      StateId cur_state = tok_map[tok];
      for (ForwardLinkT* l = tok->links; l != NULL; l = l->next) {
        typename unordered_map<Token*, StateId>::const_iterator iter =
            tok_map.find(l->next_tok);
        // Fixed: the assert must come BEFORE dereferencing the iterator;
        // previously `iter->second` was read first, which is undefined
        // behavior if the destination token were missing from tok_map.
        KALDI_ASSERT(iter != tok_map.end());
        StateId nextstate = iter->second;
        BaseFloat cost_offset = 0.0;
        if (l->ilabel != 0) {  // emitting arc: undo the per-frame acoustic
                               // cost offset that was applied during search.
          KALDI_ASSERT(f >= 0 && f < cost_offsets_.size());
          cost_offset = cost_offsets_[f];
        }
        StateId state = cur_state;
        // Context-biasing boundaries: splice in an extra state with a
        // start-tag / end-tag epsilon-input arc around this link.
        if (l->is_start_boundary) {
          StateId tmp = ofst->AddState();
          Arc arc(0, context_graph_->start_tag_id(), Weight(0, 0), tmp);
          ofst->AddArc(state, arc);
          state = tmp;
        }
        if (l->is_end_boundary) {
          StateId tmp = ofst->AddState();
          Arc arc(0, context_graph_->end_tag_id(), Weight(0, 0), nextstate);
          ofst->AddArc(tmp, arc);
          nextstate = tmp;
        }
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(state, arc);
      }
      if (f == num_frames) {
        // Last frame: attach final weights, either from the computed final
        // costs or LatticeWeight::One() when finals are not used/available.
        if (use_final_probs && !final_costs.empty()) {
          typename unordered_map<Token*, BaseFloat>::const_iterator iter =
              final_costs.find(tok);
          if (iter != final_costs.end())
            ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
        } else {
          ofst->SetFinal(cur_state, LatticeWeight::One());
        }
      }
    }
  }
  fst::TopSort(ofst);
  return (ofst->NumStates() > 0);
}
// This function is now deprecated, since now we do determinization from
// outside the LatticeFasterDecoder class.  Outputs an FST corresponding to
// the lattice-determinized lattice (one path per word sequence).
// Returns true iff the resulting lattice is non-empty.
template <typename FST, typename Token>
bool LatticeFasterDecoderTpl<FST, Token>::GetLattice(
    CompactLattice* ofst, bool use_final_probs) const {
  Lattice raw_fst;
  GetRawLattice(&raw_fst, use_final_probs);
  Invert(&raw_fst);  // make it so word labels are on the input.
  // (in phase where we get backward-costs).
  fst::ILabelCompare<LatticeArc> ilabel_comp;
  ArcSort(&raw_fst, ilabel_comp);  // sort on ilabel; makes
  // lattice-determinization more efficient.
  fst::DeterminizeLatticePrunedOptions lat_opts;
  lat_opts.max_mem = config_.det_opts.max_mem;
  DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
  raw_fst.DeleteStates();  // Free memory-- raw_fst no longer needed.
  Connect(ofst);  // Remove unreachable states... there might be
  // a small number of these, in some cases.
  // Note: if something went wrong and the raw lattice was empty,
  // we should still get to this point in the code without warnings or
  // failures.
  return (ofst->NumStates() != 0);
}
// Grows the token hash table so its size keeps up with the number of active
// tokens (scaled by config_.hash_ratio).  The hash is never shrunk.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PossiblyResizeHash(size_t num_toks) {
  const BaseFloat scaled =
      static_cast<BaseFloat>(num_toks) * config_.hash_ratio;
  const size_t target_size = static_cast<size_t>(scaled);
  if (target_size > toks_.Size()) toks_.SetSize(target_size);
}
/*
  A note on the definition of extra_cost.

  extra_cost is used in pruning tokens, to save memory.

  extra_cost can be thought of as a beta (backward) cost assuming
  we had set the betas on currently-active tokens to all be the negative
  of the alphas for those tokens.  (So all currently active tokens would
  be on (tied) best paths).

  We can use the extra_cost to accurately prune away tokens that we know will
  never appear in the lattice.  If the extra_cost is greater than the desired
  lattice beam, the token would provably never appear in the lattice, so we can
  prune away the token.

  (Note: we don't update all the extra_costs every time we update a frame; we
  only do it every 'config_.prune_interval' frames).
*/

// FindOrAddToken either locates a token in hash of toks_,
// or if necessary inserts a new, empty token (i.e. with no forward links)
// for the current frame.  [note: it's inserted if necessary into hash toks_
// and also into the singly linked list of tokens active on this frame
// (whose head is at active_toks_[frame]).
template <typename FST, typename Token>
inline typename LatticeFasterDecoderTpl<FST, Token>::Elem*
LatticeFasterDecoderTpl<FST, Token>::FindOrAddToken(StateId state,
                                                    int32 frame_plus_one,
                                                    BaseFloat tot_cost,
                                                    Token* backpointer,
                                                    bool* changed) {
  // Returns the Token pointer.  Sets "changed" (if non-NULL) to true
  // if the token was newly created or the cost changed.
  KALDI_ASSERT(frame_plus_one < active_toks_.size());
  // Reference to the head of this frame's singly-linked token list.
  Token*& toks = active_toks_[frame_plus_one].toks;
  // Insert returns the existing element for `state` if present, otherwise a
  // fresh element whose val is the NULL we pass here.
  Elem* e_found = toks_.Insert(state, NULL);
  if (e_found->val == NULL) {  // no such token presently.
    const BaseFloat extra_cost = 0.0;
    // tokens on the currently final frame have zero extra_cost
    // as any of them could end up
    // on the winning path.
    Token* new_tok = new Token(tot_cost, extra_cost, NULL, toks, backpointer);
    // NULL: no forward links yet
    toks = new_tok;  // new token becomes the list head.
    num_toks_++;
    e_found->val = new_tok;
    if (changed) *changed = true;
    return e_found;
  } else {
    Token* tok = e_found->val;  // There is an existing Token for this state.
    if (tok->tot_cost > tot_cost) {  // replace old token
      tok->tot_cost = tot_cost;
      // SetBackpointer() just does tok->backpointer = backpointer in
      // the case where Token == BackpointerToken, else nothing.
      tok->SetBackpointer(backpointer);
      // we don't allocate a new token, the old stays linked in active_toks_
      // we only replace the tot_cost
      // in the current frame, there are no forward links (and no extra_cost)
      // only in ProcessNonemitting we have to delete forward links
      // in case we visit a state for the second time
      // those forward links, that lead to this replaced token before:
      // they remain and will hopefully be pruned later (PruneForwardLinks...)
      if (changed) *changed = true;
    } else {
      if (changed) *changed = false;
    }
    return e_found;
  }
}
// prunes outgoing links for all tokens in active_toks_[frame]
// it's called by PruneActiveTokens
// all links, that have link_extra_cost > lattice_beam are pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinks(
    int32 frame_plus_one, bool* extra_costs_changed, bool* links_pruned,
    BaseFloat delta) {
  // delta is the amount by which the extra_costs must change
  // If delta is larger, we'll tend to go back less far
  //    toward the beginning of the file.
  // extra_costs_changed is set to true if extra_cost was changed for any token
  // links_pruned is set to true if any link in any token was pruned

  *extra_costs_changed = false;
  *links_pruned = false;
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  if (active_toks_[frame_plus_one].toks == NULL) {
    // empty list; should not happen.
    if (!warned_) {
      KALDI_WARN << "No tokens alive [doing pruning].. warning first "
                    "time only for each utterance\n";
      warned_ = true;
    }
  }

  // We have to iterate until there is no more change, because the links
  // are not guaranteed to be in topological order.
  bool changed = true;  // difference new minus old extra cost >= delta ?
  while (changed) {
    changed = false;
    for (Token* tok = active_toks_[frame_plus_one].toks; tok != NULL;
         tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // will recompute tok_extra_cost for tok.
      BaseFloat tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      // tok_extra_cost is the best (min) of link_extra_cost of outgoing links
      for (link = tok->links; link != NULL;) {
        // See if we need to excise this link...
        Token* next_tok = link->next_tok;
        BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);  // difference in brackets is >= 0
        // link_exta_cost is the difference in score between the best paths
        // through link source state and through link destination state
        KALDI_ASSERT(link_extra_cost == link_extra_cost);  // check for NaN
        // the graph_cost contatins the context score
        // if it's the score of the backoff arc, it should be removed.
        if (link->context_score < 0) {
          link_extra_cost += link->context_score;
        }
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          // Unlink from the singly-linked list, delete, and continue without
          // advancing prev_link.
          ForwardLinkT* next_link = link->next;
          if (prev_link != NULL)
            prev_link->next = next_link;
          else
            tok->links = next_link;
          delete link;
          link = next_link;  // advance link but leave prev_link the same.
          *links_pruned = true;
        } else {  // keep the link and update the tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            // if (link_extra_cost < -0.01)
            //   KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost)
            tok_extra_cost = link_extra_cost;
          prev_link = link;  // move to next link
          link = link->next;
        }
      }  // for all outgoing links
      if (fabs(tok_extra_cost - tok->extra_cost) > delta)
        changed = true;  // difference new minus old is bigger than delta
      tok->extra_cost = tok_extra_cost;
      // will be +infinity or <= lattice_beam_.
      // infinity indicates, that no forward link survived pruning
    }  // for all Token on active_toks_[frame]
    if (changed) *extra_costs_changed = true;

    // Note: it's theoretically possible that aggressive compiler
    // optimizations could cause an infinite loop here for small delta and
    // high-dynamic-range scores.
  }  // while changed
}
// PruneForwardLinksFinal is a version of PruneForwardLinks that we call
// on the final frame.  If there are final tokens active, it uses
// the final-probs for pruning, otherwise it treats all tokens as final.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneForwardLinksFinal() {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame_plus_one = active_toks_.size() - 1;

  if (active_toks_[frame_plus_one].toks == NULL)
    // empty list; should not happen.
    KALDI_WARN << "No tokens alive at end of file";

  typedef typename unordered_map<Token*, BaseFloat>::const_iterator IterType;
  // Cache the final costs and mark decoding as finalized; after this,
  // ComputeFinalCosts() may no longer be called on this object.
  ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
  decoding_finalized_ = true;
  // We call DeleteElems() as a nicety, not because it's really necessary;
  // otherwise there would be a time, after calling PruneTokensForFrame() on
  // the final frame, when toks_.GetList() or toks_.Clear() would contain
  // pointers to nonexistent tokens.
  DeleteElems(toks_.Clear());

  // Now go through tokens on this frame, pruning forward links...  may have
  // to iterate a few times until there is no more change, because the list is
  // not in topological order.  This is a modified version of the code in
  // PruneForwardLinks, but here we also take account of the final-probs.
  bool changed = true;
  BaseFloat delta = 1.0e-05;
  while (changed) {
    changed = false;
    for (Token* tok = active_toks_[frame_plus_one].toks; tok != NULL;
         tok = tok->next) {
      ForwardLinkT *link, *prev_link = NULL;
      // will recompute tok_extra_cost.  It has a term in it that corresponds
      // to the "final-prob", so instead of initializing tok_extra_cost to
      // infinity below we set it to the difference between the
      // (score+final_prob) of this token,
      // and the best such (score+final_prob).
      BaseFloat final_cost;
      if (final_costs_.empty()) {
        final_cost = 0.0;
      } else {
        IterType iter = final_costs_.find(tok);
        if (iter != final_costs_.end())
          final_cost = iter->second;
        else
          // Token is not final in the FST: infinitely bad as a final token.
          final_cost = std::numeric_limits<BaseFloat>::infinity();
      }
      BaseFloat tok_extra_cost = tok->tot_cost + final_cost - final_best_cost_;
      // tok_extra_cost will be a "min" over either directly being final, or
      // being indirectly final through other links, and the loop below may
      // decrease its value:
      for (link = tok->links; link != NULL;) {
        // See if we need to excise this link...
        Token* next_tok = link->next_tok;
        BaseFloat link_extra_cost =
            next_tok->extra_cost +
            ((tok->tot_cost + link->acoustic_cost + link->graph_cost) -
             next_tok->tot_cost);
        if (link_extra_cost > config_.lattice_beam) {  // excise link
          ForwardLinkT* next_link = link->next;
          if (prev_link != NULL)
            prev_link->next = next_link;
          else
            tok->links = next_link;
          delete link;
          link = next_link;  // advance link but leave prev_link the same.
        } else {  // keep the link and update the tok_extra_cost if needed.
          if (link_extra_cost < 0.0) {  // this is just a precaution.
            // if (link_extra_cost < -0.01)
            //   KALDI_WARN << "Negative extra_cost: " << link_extra_cost;
            link_extra_cost = 0.0;
          }
          if (link_extra_cost < tok_extra_cost)
            tok_extra_cost = link_extra_cost;
          prev_link = link;
          link = link->next;
        }
      }
      // prune away tokens worse than lattice_beam above best path.  This step
      // was not necessary in the non-final case because then, this case
      // showed up as having no forward links.  Here, the tok_extra_cost has
      // an extra component relating to the final-prob.
      if (tok_extra_cost > config_.lattice_beam)
        tok_extra_cost = std::numeric_limits<BaseFloat>::infinity();
      // to be pruned in PruneTokensForFrame

      if (!ApproxEqual(tok->extra_cost, tok_extra_cost, delta)) changed = true;
      tok->extra_cost = tok_extra_cost;
      // will be +infinity or <= lattice_beam_.
    }
  }  // while changed
}
// Returns the relative cost of reaching a final state from the best current
// token.  Before finalization the value is computed on demand; after
// FinalizeDecoding() has run, ComputeFinalCosts() may no longer be called,
// so the value cached at finalization time is returned instead.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::FinalRelativeCost() const {
  if (decoding_finalized_) return final_relative_cost_;
  BaseFloat relative_cost;
  ComputeFinalCosts(NULL, &relative_cost, NULL);
  return relative_cost;
}
// Prune away any tokens on this frame that have no forward links.
// [we don't do this in PruneForwardLinks because it would give us
// a problem with dangling pointers].
// It's called by PruneActiveTokens if any forward links have been pruned
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneTokensForFrame(
    int32 frame_plus_one) {
  KALDI_ASSERT(frame_plus_one >= 0 && frame_plus_one < active_toks_.size());
  Token*& toks = active_toks_[frame_plus_one].toks;
  if (toks == NULL) KALDI_WARN << "No tokens alive [doing pruning]";
  // Walk the singly-linked list, unlinking and deleting tokens whose
  // extra_cost is +infinity (set by PruneForwardLinks* when no outgoing
  // link survived pruning).
  Token *tok, *next_tok, *prev_tok = NULL;
  for (tok = toks; tok != NULL; tok = next_tok) {
    next_tok = tok->next;  // save before possibly deleting tok.
    if (tok->extra_cost == std::numeric_limits<BaseFloat>::infinity()) {
      // token is unreachable from end of graph; (no forward links survived)
      // excise tok from list and delete tok.
      if (prev_tok != NULL)
        prev_tok->next = tok->next;
      else
        toks = tok->next;
      delete tok;
      num_toks_--;
    } else {  // fetch next Token
      prev_tok = tok;
    }
  }
}
// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame
// before that.  We go backwards through the frames and stop when we reach a
// point where the delta-costs are not changing (and the delta controls when
// we consider a cost to have "not changed").
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::PruneActiveTokens(BaseFloat delta) {
  int32 cur_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;  // for the VLOG statistics only.
  // The index "f" below represents a "frame plus one", i.e. you'd have to
  // subtract one to get the corresponding index for the decodable object.
  for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
    // Reason why we need to prune forward links in this situation:
    // (1) we have never pruned them (new TokenList)
    // (2) we have not yet pruned the forward links to the next f,
    // after any of those tokens have changed their extra_cost.
    if (active_toks_[f].must_prune_forward_links) {
      bool extra_costs_changed = false, links_pruned = false;
      PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
      if (extra_costs_changed && f > 0)  // any token has changed extra_cost
        // Changed extra_costs propagate backwards: the previous frame's
        // links must be re-examined.
        active_toks_[f - 1].must_prune_forward_links = true;
      if (links_pruned)  // any link was pruned
        active_toks_[f].must_prune_tokens = true;
      active_toks_[f].must_prune_forward_links = false;  // job done
    }
    if (f + 1 < cur_frame_plus_one &&  // except for last f (no forward links)
        active_toks_[f + 1].must_prune_tokens) {
      PruneTokensForFrame(f + 1);
      active_toks_[f + 1].must_prune_tokens = false;
    }
  }
  KALDI_VLOG(4) << "PruneActiveTokens: pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}
// Computes, over all currently-active tokens, the best total cost with and
// without final-probs, and (optionally) the per-token final costs.
// Only valid before FinalizeDecoding() has been called.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ComputeFinalCosts(
    unordered_map<Token *, BaseFloat> *final_costs,
    BaseFloat *final_relative_cost, BaseFloat *final_best_cost) const {
  KALDI_ASSERT(!decoding_finalized_);
  if (final_costs != NULL) final_costs->clear();
  const BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_cost = infinity;             // best tot_cost, final-prob ignored
  BaseFloat best_cost_with_final = infinity;  // best tot_cost + final-prob
  // Walk the hash-list of tokens active on the current frame.
  for (const Elem *elem = toks_.GetList(); elem != NULL; elem = elem->tail) {
    StateId state = elem->key;
    Token *tok = elem->val;
    BaseFloat final_cost = fst_->Final(state).Value();
    BaseFloat cost = tok->tot_cost;
    BaseFloat cost_with_final = cost + final_cost;
    best_cost = std::min(cost, best_cost);
    best_cost_with_final = std::min(cost_with_final, best_cost_with_final);
    // Record the final-prob only for tokens in actual final states.
    if (final_costs != NULL && final_cost != infinity)
      (*final_costs)[tok] = final_cost;
  }
  if (final_relative_cost != NULL) {
    if (best_cost == infinity && best_cost_with_final == infinity) {
      // Likely this will only happen if there are no tokens surviving.
      // This seems the least bad way to handle it.
      *final_relative_cost = infinity;
    } else {
      *final_relative_cost = best_cost_with_final - best_cost;
    }
  }
  if (final_best_cost != NULL) {
    // If any final-state was reachable, report the best cost including
    // final-probs; otherwise fall back to the best cost without them.
    *final_best_cost =
        (best_cost_with_final != infinity) ? best_cost_with_final : best_cost;
  }
}
// Decodes frames from 'decodable' until no more are ready (or until
// 'max_num_frames' additional frames, if it is >= 0), interleaving
// token pruning every config_.prune_interval frames.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::AdvanceDecoding(
    DecodableInterface *decodable, int32 max_num_frames) {
  if (std::is_same<FST, fst::Fst<fst::StdArc> >::value) {
    // if the type 'FST' is the FST base-class, then see if the FST type of fst_
    // is actually VectorFst or ConstFst.  If so, call the AdvanceDecoding()
    // function after casting *this to the more specific type.
    // NOTE(review): this reinterpret_cast between differently-templated
    // decoder types relies on the instantiations being layout-compatible;
    // it mirrors upstream Kaldi and is assumed safe here.
    if (fst_->Type() == "const") {
      LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<
              LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>, Token> *>(
              this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    } else if (fst_->Type() == "vector") {
      LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *this_cast =
          reinterpret_cast<
              LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>, Token> *>(
              this);
      this_cast->AdvanceDecoding(decodable, max_num_frames);
      return;
    }
  }

  KALDI_ASSERT(!active_toks_.empty() && !decoding_finalized_ &&
               "You must call InitDecoding() before AdvanceDecoding");
  int32 num_frames_ready = decodable->NumFramesReady();
  // num_frames_ready must be >= num_frames_decoded, or else
  // the number of frames ready must have decreased (which doesn't
  // make sense) or the decodable object changed between calls
  // (which isn't allowed).
  KALDI_ASSERT(num_frames_ready >= NumFramesDecoded());
  int32 target_frames_decoded = num_frames_ready;
  if (max_num_frames >= 0)
    target_frames_decoded =
        std::min(target_frames_decoded, NumFramesDecoded() + max_num_frames);
  while (NumFramesDecoded() < target_frames_decoded) {
    // Periodically prune tokens to bound memory/work.
    if (NumFramesDecoded() % config_.prune_interval == 0) {
      PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
    }
    // Emitting arcs consume one frame; the returned cutoff then bounds the
    // epsilon (non-emitting) propagation on the new frame.
    BaseFloat cost_cutoff = ProcessEmitting(decodable);
    ProcessNonemitting(cost_cutoff);
  }
}
// FinalizeDecoding() is a version of PruneActiveTokens that we call
// (optionally) on the final frame.  Takes into account the final-prob of
// tokens.  This function used to be called PruneActiveTokensFinal().
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::FinalizeDecoding() {
  int32 final_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;  // for the log message only.
  // PruneForwardLinksFinal() prunes final frame (with final-probs), and
  // sets decoding_finalized_.
  PruneForwardLinksFinal();
  // Sweep backwards over all earlier frames with a zero delta so every
  // extra_cost is fully updated, then drop unreachable tokens per frame.
  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
    bool unused_costs_changed;   // value not used.
    bool unused_links_pruned;    // value not used.
    const BaseFloat zero_delta = 0.0;  // delta of zero means we must always update
    PruneForwardLinks(f, &unused_costs_changed, &unused_links_pruned,
                      zero_delta);
    PruneTokensForFrame(f + 1);
  }
  PruneTokensForFrame(0);
  KALDI_VLOG(4) << "pruned tokens from " << num_toks_begin << " to "
                << num_toks_;
}
/// Gets the weight cutoff.  Also counts the active tokens.
/// 'tok_count', 'adaptive_beam' and 'best_elem' are all optional outputs
/// (they may be NULL); the return value is the pruning cutoff for the
/// current frame.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::GetCutoff(
    Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam,
    Elem **best_elem) {
  BaseFloat best_weight = std::numeric_limits<BaseFloat>::infinity();
  // positive == high cost == bad.
  size_t count = 0;
  if (config_.max_active == std::numeric_limits<int32>::max() &&
      config_.min_active == 0) {
    // Fast path: no max/min-active constraints, so the cutoff is simply
    // best cost + beam; just find the best token while counting.
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = static_cast<BaseFloat>(e->val->tot_cost);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;
    if (adaptive_beam != NULL) *adaptive_beam = config_.beam;
    return best_weight + config_.beam;
  } else {
    // General path: gather all costs so we can apply the max_active /
    // min_active constraints via nth_element.
    tmp_array_.clear();
    for (Elem *e = list_head; e != NULL; e = e->tail, count++) {
      BaseFloat w = e->val->tot_cost;
      tmp_array_.push_back(w);
      if (w < best_weight) {
        best_weight = w;
        if (best_elem) *best_elem = e;
      }
    }
    if (tok_count != NULL) *tok_count = count;

    BaseFloat beam_cutoff = best_weight + config_.beam,
              min_active_cutoff = std::numeric_limits<BaseFloat>::infinity(),
              max_active_cutoff = std::numeric_limits<BaseFloat>::infinity();

    KALDI_VLOG(6) << "Number of tokens active on frame " << NumFramesDecoded()
                  << " is " << tmp_array_.size();

    if (tmp_array_.size() > static_cast<size_t>(config_.max_active)) {
      // Cost of the (max_active+1)-th best token: tokens at or above this
      // cost would exceed the max-active budget.
      std::nth_element(tmp_array_.begin(),
                       tmp_array_.begin() + config_.max_active,
                       tmp_array_.end());
      max_active_cutoff = tmp_array_[config_.max_active];
    }
    if (max_active_cutoff < beam_cutoff) {  // max_active is tighter than beam.
      if (adaptive_beam)
        *adaptive_beam = max_active_cutoff - best_weight + config_.beam_delta;
      return max_active_cutoff;
    }
    if (tmp_array_.size() > static_cast<size_t>(config_.min_active)) {
      if (config_.min_active == 0) {
        min_active_cutoff = best_weight;
      } else {
        // Cost of the min_active-th best token; only partition within the
        // first max_active elements if a previous nth_element already
        // placed them.
        std::nth_element(
            tmp_array_.begin(), tmp_array_.begin() + config_.min_active,
            tmp_array_.size() > static_cast<size_t>(config_.max_active)
                ? tmp_array_.begin() + config_.max_active
                : tmp_array_.end());
        min_active_cutoff = tmp_array_[config_.min_active];
      }
    }
    if (min_active_cutoff > beam_cutoff) {  // min_active is looser than beam.
      if (adaptive_beam)
        *adaptive_beam = min_active_cutoff - best_weight + config_.beam_delta;
      return min_active_cutoff;
    } else {
      // BUGFIX: guard against NULL, consistent with every other path in
      // this function (the original dereferenced unconditionally here).
      if (adaptive_beam) *adaptive_beam = config_.beam;
      return beam_cutoff;
    }
  }
}
// Processes emitting (ilabel != 0) arcs for one frame: consumes one frame of
// acoustic likelihoods, creates tokens on the next frame, and returns the
// cost cutoff to use for the subsequent non-emitting propagation.
template <typename FST, typename Token>
BaseFloat LatticeFasterDecoderTpl<FST, Token>::ProcessEmitting(
    DecodableInterface *decodable) {
  KALDI_ASSERT(active_toks_.size() > 0);
  int32 frame = active_toks_.size() - 1;  // frame is the frame-index
                                          // (zero-based) used to get likelihoods
                                          // from the decodable object.
  active_toks_.resize(active_toks_.size() + 1);

  Elem *final_toks = toks_.Clear();  // analogous to swapping prev_toks_ / cur_toks_
                                     // in simple-decoder.h.  Removes the Elems from
                                     // being indexed in the hash in toks_.
  Elem *best_elem = NULL;
  BaseFloat adaptive_beam;
  size_t tok_cnt;
  BaseFloat cur_cutoff =
      GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
  KALDI_VLOG(6) << "Adaptive beam on frame " << NumFramesDecoded() << " is "
                << adaptive_beam;

  PossiblyResizeHash(tok_cnt);  // This makes sure the hash is always big enough.

  BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
  // pruning "online" before having seen all tokens

  BaseFloat cost_offset = 0.0;  // Used to keep probabilities in a good
                                // dynamic range.

  // First process the best token to get a hopefully
  // reasonably tight bound on the next cutoff.  The only
  // products of the next block are "next_cutoff" and "cost_offset".
  if (best_elem) {
    StateId state = best_elem->key;
    Token *tok = best_elem->val;
    cost_offset = -tok->tot_cost;
    for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel != 0) {  // propagate..
        BaseFloat new_weight = arc.weight.Value() + cost_offset -
                               decodable->LogLikelihood(frame, arc.ilabel) +
                               tok->tot_cost;
        // Penalize non-self-loop transitions to balance del/ins errors.
        if (state != arc.nextstate) {
          new_weight += config_.length_penalty;
        }
        if (new_weight + adaptive_beam < next_cutoff)
          next_cutoff = new_weight + adaptive_beam;
      }
    }
  }

  // Store the offset on the acoustic likelihoods that we're applying.
  // Could just do cost_offsets_.push_back(cost_offset), but we
  // do it this way as it's more robust to future code changes.
  cost_offsets_.resize(frame + 1, 0.0);
  cost_offsets_[frame] = cost_offset;

  // the tokens are now owned here, in final_toks, and the hash is empty.
  // 'owned' is a complex thing here; the point is we need to call DeleteElem
  // on each elem 'e' to let toks_ know we're done with them.
  for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) {
    // loop this way because we delete "e" as we go.
    StateId state = e->key;
    Token *tok = e->val;
    if (tok->tot_cost <= cur_cutoff) {
      for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
           aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (arc.ilabel != 0) {  // propagate..
          BaseFloat ac_cost = cost_offset -
                              decodable->LogLikelihood(frame, arc.ilabel),
                    graph_cost = arc.weight.Value();
          // Same length penalty as in the best-token pass above.
          if (state != arc.nextstate) {
            graph_cost += config_.length_penalty;
          }
          BaseFloat cur_cost = tok->tot_cost,
                    tot_cost = cur_cost + ac_cost + graph_cost;
          if (tot_cost >= next_cutoff)
            continue;
          else if (tot_cost + adaptive_beam < next_cutoff)
            next_cutoff = tot_cost + adaptive_beam;  // prune by best current token
          // Note: the frame indexes into active_toks_ are one-based,
          // hence the + 1.
          Elem *e_next =
              FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok, NULL);
          // NULL: no change indicator needed

          // Optional contextual-biasing (hotword) scoring: advance the
          // context-graph state on output labels and subtract the bonus.
          bool is_start_boundary = false;
          bool is_end_boundary = false;
          float context_score = 0;
          if (context_graph_) {
            if (arc.olabel == 0) {
              e_next->val->context_state = tok->context_state;
            } else {
              e_next->val->context_state = context_graph_->GetNextState(
                  tok->context_state, arc.olabel, &context_score,
                  &is_start_boundary, &is_end_boundary);
              graph_cost -= context_score;
            }
          }
          // Add ForwardLink from tok to next_tok (put on head of list
          // tok->links)
          tok->links = new ForwardLinkT(e_next->val, arc.ilabel, arc.olabel,
                                        graph_cost, ac_cost, is_start_boundary,
                                        is_end_boundary, tok->links);
          tok->links->context_score = context_score;
        }
      }  // for all arcs
    }
    e_tail = e->tail;
    toks_.Delete(e);  // delete Elem
  }
  return next_cutoff;
}
// static inline
// Frees the entire singly-linked list of forward links hanging off 'tok'
// and resets the list head to NULL.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteForwardLinks(Token *tok) {
  ForwardLinkT *cur = tok->links;
  while (cur != NULL) {
    ForwardLinkT *following = cur->next;  // save successor before freeing.
    delete cur;
    cur = following;
  }
  tok->links = NULL;
}
// Propagates tokens through non-emitting (input-epsilon) arcs within the
// current frame, up to the given cost 'cutoff'.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ProcessNonemitting(BaseFloat cutoff) {
  KALDI_ASSERT(!active_toks_.empty());
  int32 frame = static_cast<int32>(active_toks_.size()) - 2;
  // Note: "frame" is the time-index we just processed, or -1 if
  // we are processing the nonemitting transitions before the
  // first frame (called from InitDecoding()).

  // Processes nonemitting arcs for one frame.  Propagates within toks_.
  // Note-- this queue structure is not very optimal as
  // it may cause us to process states unnecessarily (e.g. more than once),
  // but in the baseline code, turning this vector into a set to fix this
  // problem did not improve overall speed.
  KALDI_ASSERT(queue_.empty());

  if (toks_.GetList() == NULL) {
    if (!warned_) {
      KALDI_WARN << "Error, no surviving tokens: frame is " << frame;
      warned_ = true;  // warn only once per utterance.
    }
  }

  // Seed the queue with every state that has at least one epsilon out-arc.
  int before = 0, after = 0;
  for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) {
    StateId state = e->key;
    if (fst_->NumInputEpsilons(state) != 0) queue_.push_back(e);
    ++before;
  }

  while (!queue_.empty()) {
    ++after;
    const Elem *e = queue_.back();
    queue_.pop_back();

    StateId state = e->key;
    Token *tok = e->val;
    // would segfault if e is a NULL pointer but this can't happen.
    BaseFloat cur_cost = tok->tot_cost;
    if (cur_cost >= cutoff)  // Don't bother processing successors.
      continue;
    // If "tok" has any existing forward links, delete them,
    // because we're about to regenerate them.  This is a kind
    // of non-optimality (remember, this is the simple decoder),
    // but since most states are emitting it's not a huge issue.
    DeleteForwardLinks(tok);  // necessary when re-visiting
    tok->links = NULL;        // redundant (DeleteForwardLinks did it); kept as-is.
    for (fst::ArcIterator<FST> aiter(*fst_, state); !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (arc.ilabel == 0) {  // propagate nonemitting only...
        BaseFloat graph_cost = arc.weight.Value(),
                  tot_cost = cur_cost + graph_cost;
        if (tot_cost < cutoff) {
          bool changed;
          Elem *e_new = FindOrAddToken(arc.nextstate, frame + 1, tot_cost, tok,
                                       &changed);

          // Optional contextual-biasing scoring (see ProcessEmitting).
          bool is_start_boundary = false;
          bool is_end_boundary = false;
          float context_score = 0;
          if (context_graph_) {
            if (arc.olabel == 0) {
              e_new->val->context_state = tok->context_state;
            } else {
              e_new->val->context_state = context_graph_->GetNextState(
                  tok->context_state, arc.olabel, &context_score,
                  &is_start_boundary, &is_end_boundary);
              graph_cost -= context_score;
            }
          }
          // Epsilon link: ilabel 0, acoustic cost 0.
          tok->links = new ForwardLinkT(e_new->val, 0, arc.olabel, graph_cost,
                                        0, is_start_boundary, is_end_boundary,
                                        tok->links);
          tok->links->context_score = context_score;

          // "changed" tells us whether the new token has a different
          // cost from before, or is new [if so, add into queue].
          if (changed && fst_->NumInputEpsilons(arc.nextstate) != 0)
            queue_.push_back(e_new);
        }
      }
    }  // for all arcs
  }    // while queue not empty
  KALDI_VLOG(3) << "ProcessNonemitting " << before << " " << after;
}
// Releases every Elem on 'list' back to the toks_ hash-list allocator.
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::DeleteElems(Elem *list) {
  Elem *cur = list;
  while (cur != NULL) {
    Elem *following = cur->tail;  // save successor before releasing 'cur'.
    toks_.Delete(cur);
    cur = following;
  }
}
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::ClearActiveTokens() {
  // a cleanup routine, at utt end/begin
  // Delete all tokens alive on every frame, and any forward links they
  // may have, keeping the live-token count consistent.
  for (size_t frame_idx = 0; frame_idx < active_toks_.size(); frame_idx++) {
    Token *cur = active_toks_[frame_idx].toks;
    while (cur != NULL) {
      DeleteForwardLinks(cur);
      Token *following = cur->next;  // save before deleting 'cur'.
      delete cur;
      num_toks_--;
      cur = following;
    }
  }
  active_toks_.clear();
  KALDI_ASSERT(num_toks_ == 0);  // every token must have been accounted for.
}
// static
// Topologically sorts the tokens of a single frame (linked by epsilon
// forward links) into *topsorted_list.  Positions may be sparse (NULL gaps).
template <typename FST, typename Token>
void LatticeFasterDecoderTpl<FST, Token>::TopSortTokens(
    Token *tok_list, std::vector<Token *> *topsorted_list) {
  unordered_map<Token *, int32> token2pos;
  using std::unordered_set;
  typedef typename unordered_map<Token *, int32>::iterator IterType;
  int32 num_toks = 0;
  for (Token *tok = tok_list; tok != NULL; tok = tok->next) num_toks++;
  int32 cur_pos = 0;
  // We assign the tokens numbers num_toks - 1, ... , 2, 1, 0.
  // This is likely to be in closer to topological order than
  // if we had given them ascending order, because of the way
  // new tokens are put at the front of the list.
  for (Token *tok = tok_list; tok != NULL; tok = tok->next)
    token2pos[tok] = num_toks - ++cur_pos;

  unordered_set<Token *> reprocess;

  // First pass: push any epsilon-successor that currently sorts BEFORE its
  // predecessor to a fresh (larger) position, and mark it for reprocessing.
  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter) {
    Token *tok = iter->first;
    int32 pos = iter->second;
    for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
      if (link->ilabel == 0) {
        // We only need to consider epsilon links, since non-epsilon links
        // transition between frames and this function only needs to sort a list
        // of tokens from a single frame.
        IterType following_iter = token2pos.find(link->next_tok);
        if (following_iter != token2pos.end()) {  // another token on this
                                                  // frame, so must consider it.
          int32 next_pos = following_iter->second;
          if (next_pos < pos) {  // reassign the position of the next Token.
            following_iter->second = cur_pos++;
            reprocess.insert(link->next_tok);
          }
        }
      }
    }
    // In case we had previously assigned this token to be reprocessed, we can
    // erase it from that set because it's "happy now" (we just processed it).
    reprocess.erase(tok);
  }

  size_t max_loop = 1000000, loop_count;  // max_loop is to detect epsilon cycles.
  // Iterate until no token's position needed fixing; each round drains a
  // snapshot of 'reprocess' and may refill it.
  for (loop_count = 0; !reprocess.empty() && loop_count < max_loop;
       ++loop_count) {
    std::vector<Token *> reprocess_vec;
    for (typename unordered_set<Token *>::iterator iter = reprocess.begin();
         iter != reprocess.end(); ++iter)
      reprocess_vec.push_back(*iter);
    reprocess.clear();
    for (typename std::vector<Token *>::iterator iter = reprocess_vec.begin();
         iter != reprocess_vec.end(); ++iter) {
      Token *tok = *iter;
      int32 pos = token2pos[tok];
      // Repeat the processing we did above (for comments, see above).
      for (ForwardLinkT *link = tok->links; link != NULL; link = link->next) {
        if (link->ilabel == 0) {
          IterType following_iter = token2pos.find(link->next_tok);
          if (following_iter != token2pos.end()) {
            int32 next_pos = following_iter->second;
            if (next_pos < pos) {
              following_iter->second = cur_pos++;
              reprocess.insert(link->next_tok);
            }
          }
        }
      }
    }
  }
  KALDI_ASSERT(loop_count < max_loop &&
               "Epsilon loops exist in your decoding "
               "graph (this is not allowed!)");

  topsorted_list->clear();
  topsorted_list->resize(cur_pos, NULL);  // create a list with NULLs in between.
  for (IterType iter = token2pos.begin(); iter != token2pos.end(); ++iter)
    (*topsorted_list)[iter->second] = iter->first;
}
// Instantiate the template for the combination of token types and FST types
// that we'll need.
template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
                                       decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>,
                                       decoder::StdToken>;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
                                       decoder::StdToken>;
// template class LatticeFasterDecoderTpl<fst::ConstGrammarFst,
// decoder::StdToken>; template class
// LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::StdToken>;

template class LatticeFasterDecoderTpl<fst::Fst<fst::StdArc>,
                                       decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::VectorFst<fst::StdArc>,
                                       decoder::BackpointerToken>;
template class LatticeFasterDecoderTpl<fst::ConstFst<fst::StdArc>,
                                       decoder::BackpointerToken>;
// template class LatticeFasterDecoderTpl<fst::ConstGrammarFst,
// decoder::BackpointerToken>; template class
// LatticeFasterDecoderTpl<fst::VectorGrammarFst, decoder::BackpointerToken>;

}  // end namespace kaldi.
runtime/core/kaldi/decoder/lattice-faster-decoder.h
0 → 100644
View file @
764b3a75
// decoder/lattice-faster-decoder.h
// Copyright 2009-2013 Microsoft Corporation; Mirko Hannemann;
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2018 Zhehuai Chen
// 2021 Binbin Zhang, Zhendong Peng
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#define KALDI_DECODER_LATTICE_FASTER_DECODER_H_
#include <limits>
#include <memory>
#include <unordered_map>
#include <vector>
#include "base/kaldi-common.h"
#include "decoder/context_graph.h"
#include "fst/fstlib.h"
#include "fstext/fstext-lib.h"
#include "itf/decodable-itf.h"
#include "lat/determinize-lattice-pruned.h"
#include "lat/kaldi-lattice.h"
#include "util/hash-list.h"
namespace
kaldi
{
// Configuration options for LatticeFasterDecoderTpl.  Defaults are set in
// the constructor; Register() exposes them on the command line; Check()
// validates mutual consistency.
struct LatticeFasterDecoderConfig {
  BaseFloat beam;              // main decoding beam (cost width kept per frame).
  int32 max_active;            // upper bound on tokens per frame.
  int32 min_active;            // lower bound on tokens per frame.
  BaseFloat lattice_beam;      // beam used when pruning the lattice.
  int32 prune_interval;        // prune tokens every this many frames.
  bool determinize_lattice;  // not inspected by this class... used in
                             // command-line program.
  BaseFloat beam_delta;        // slack added when the max-active cutoff applies.
  BaseFloat hash_ratio;        // controls resizing of the token hash.
  // Note: we don't make prune_scale configurable on the command line, it's not
  // a very important parameter.  It affects the algorithm that prunes the
  // tokens as we go.
  BaseFloat prune_scale;
  BaseFloat length_penalty;  // for balancing the del/ins ratio, suggested -3.0

  // Most of the options inside det_opts are not actually queried by the
  // LatticeFasterDecoder class itself, but by the code that calls it, for
  // example in the function DecodeUtteranceLatticeFaster.
  fst::DeterminizeLatticePhonePrunedOptions det_opts;

  LatticeFasterDecoderConfig()
      : beam(16.0),
        max_active(std::numeric_limits<int32>::max()),
        min_active(200),
        lattice_beam(10.0),
        prune_interval(25),
        determinize_lattice(true),
        beam_delta(0.5),
        hash_ratio(2.0),
        prune_scale(0.1),
        length_penalty(0.0) {}

  // Registers the configurable options on 'opts' (command-line parsing).
  void Register(OptionsItf *opts) {
    det_opts.Register(opts);
    opts->Register("beam", &beam,
                   "Decoding beam. Larger->slower, more accurate.");
    opts->Register("max-active", &max_active,
                   "Decoder max active states. Larger->slower; "
                   "more accurate");
    opts->Register("min-active", &min_active, "Decoder minimum #active states.");
    opts->Register("lattice-beam", &lattice_beam,
                   "Lattice generation beam. Larger->slower, "
                   "and deeper lattices");
    opts->Register("prune-interval", &prune_interval,
                   "Interval (in frames) at "
                   "which to prune tokens");
    opts->Register("determinize-lattice", &determinize_lattice,
                   "If true, "
                   "determinize the lattice (lattice-determinization, keeping only "
                   "best pdf-sequence for each word-sequence).");
    opts->Register("beam-delta", &beam_delta,
                   "Increment used in decoding-- this "
                   "parameter is obscure and relates to a speedup in the way the "
                   "max-active constraint is applied. Larger is more accurate.");
    opts->Register("hash-ratio", &hash_ratio,
                   "Setting used in decoder to "
                   "control hash behavior");
  }

  // Sanity-checks the option values; aborts (KALDI_ASSERT) on invalid config.
  void Check() const {
    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0 &&
                 min_active <= max_active && prune_interval > 0 &&
                 beam_delta > 0.0 && hash_ratio >= 1.0 && prune_scale > 0.0 &&
                 prune_scale < 1.0);
  }
};
namespace
decoder
{
// We will template the decoder on the token type as well as the FST type; this
// is a mechanism so that we can use the same underlying decoder code for
// versions of the decoder that support quickly getting the best path
// (LatticeFasterOnlineDecoder, see lattice-faster-online-decoder.h) and also
// those that do not (LatticeFasterDecoder).

// ForwardLinks are the links from a token to a token on the next frame.
// or sometimes on the current frame (for input-epsilon links).
template <typename Token>
struct ForwardLink {
  using Label = fst::StdArc::Label;

  Token *next_tok;          // the next token [or NULL if represents final-state]
  Label ilabel;             // ilabel on arc
  Label olabel;             // olabel on arc
  BaseFloat graph_cost;     // graph cost of traversing arc (contains LM, etc.)
  BaseFloat acoustic_cost;  // acoustic cost (pre-scaled) of traversing arc
  bool is_start_boundary;   // context-graph boundary flags (see ContextGraph).
  bool is_end_boundary;
  float context_score;      // contextual-biasing score applied on this arc.
  ForwardLink *next;        // next in singly-linked list of forward arcs (arcs
                            // in the state-level lattice) from a token.
  inline ForwardLink(Token *next_tok, Label ilabel, Label olabel,
                     BaseFloat graph_cost, BaseFloat acoustic_cost,
                     bool is_start_boundary, bool is_end_boundary,
                     ForwardLink *next)
      : next_tok(next_tok),
        ilabel(ilabel),
        olabel(olabel),
        graph_cost(graph_cost),
        acoustic_cost(acoustic_cost),
        is_start_boundary(is_start_boundary),
        is_end_boundary(is_end_boundary),
        context_score(0),
        next(next) {}
};
// Standard token type for LatticeFasterDecoder.  Each active HCLG
// (decoding-graph) state on each frame has one token.
struct StdToken {
  using ForwardLinkT = ForwardLink<StdToken>;
  using Token = StdToken;

  // tot_cost is the total (LM + acoustic) cost from the beginning of the
  // utterance up to this point.  (but see cost_offset_, which is subtracted
  // to keep it in a good numerical range).
  BaseFloat tot_cost;

  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals the
  // minimum difference between the cost of the best path that this link is a
  // part of, and the cost of the absolute best path, under the assumption that
  // any of the currently active states at the decoding front may eventually
  // succeed (e.g. if you were to take the currently active states one by one
  // and compute this difference, and then take the minimum).
  BaseFloat extra_cost;

  // State in the contextual-biasing graph (0 = start state).
  int context_state = 0;

  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
  // use for lattice generation.
  ForwardLinkT *links;

  // 'next' is the next in the singly-linked list of tokens for this frame.
  Token *next;

  // This function does nothing and should be optimized out; it's needed
  // so we can share the regular LatticeFasterDecoderTpl code and the code
  // for LatticeFasterOnlineDecoder that supports fast traceback.
  inline void SetBackpointer(Token *backpointer) {}

  // This constructor just ignores the 'backpointer' argument.  That argument is
  // needed so that we can use the same decoder code for LatticeFasterDecoderTpl
  // and LatticeFasterOnlineDecoderTpl (which needs backpointers to support a
  // fast way to obtain the best path).
  // BUGFIX: mem-initializers reordered to match member declaration order
  // (context_state is declared before links); the original order triggered
  // a -Wreorder warning.  Behavior is unchanged.
  inline StdToken(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLinkT *links,
                  Token *next, Token *backpointer)
      : tot_cost(tot_cost),
        extra_cost(extra_cost),
        context_state(0),
        links(links),
        next(next) {}
};
// BackpointerToken is like StdToken but additionally stores a backpointer
// to the best preceding token, enabling fast best-path traceback in
// LatticeFasterOnlineDecoderTpl.  Each active HCLG (decoding-graph) state
// on each frame has one token.
struct BackpointerToken {
  using ForwardLinkT = ForwardLink<BackpointerToken>;
  using Token = BackpointerToken;

  // tot_cost is the total (LM + acoustic) cost from the beginning of the
  // utterance up to this point.  (but see cost_offset_, which is subtracted
  // to keep it in a good numerical range).
  BaseFloat tot_cost;

  // extra_cost is >= 0.  After calling PruneForwardLinks, this equals
  // the minimum difference between the cost of the best path this token is
  // on, and the cost of the absolute best path, under the assumption
  // that any of the currently active states at the decoding front may
  // eventually succeed (e.g. if you were to take the currently active states
  // one by one and compute this difference, and then take the minimum).
  BaseFloat extra_cost;

  // State in the contextual-biasing graph (0 = start state).
  int context_state = 0;

  // 'links' is the head of singly-linked list of ForwardLinks, which is what we
  // use for lattice generation.
  ForwardLinkT *links;

  // 'next' is the next in the singly-linked list of tokens for this frame.
  BackpointerToken *next;

  // Best preceding BackpointerToken (could be a token on this frame, connected
  // to this via an epsilon transition, or on a previous frame).  This is only
  // required for an efficient GetBestPath function in
  // LatticeFasterOnlineDecoderTpl; it plays no part in the lattice generation
  // (the "links" list is what stores the forward links, for that).
  Token *backpointer;

  inline void SetBackpointer(Token *backpointer) {
    this->backpointer = backpointer;
  }

  // BUGFIX: mem-initializers reordered to match member declaration order
  // (context_state precedes links/next/backpointer); the original listed
  // context_state(0) last, triggering a -Wreorder warning.  Behavior is
  // unchanged.
  inline BackpointerToken(BaseFloat tot_cost, BaseFloat extra_cost,
                          ForwardLinkT *links, Token *next, Token *backpointer)
      : tot_cost(tot_cost),
        extra_cost(extra_cost),
        context_state(0),
        links(links),
        next(next),
        backpointer(backpointer) {}
};
}
// namespace decoder
/** This is the "normal" lattice-generating decoder.
See \ref lattices_generation \ref decoders_faster and \ref decoders_simple
for more information.
The decoder is templated on the FST type and the token type. The token type
will normally be StdToken, but also may be BackpointerToken which is to
support quick lookup of the current best path (see
lattice-faster-online-decoder.h)
The FST type with which you invoke this decoder is expected to equal
Fst::Fst<fst::StdArc>, a.k.a. StdFst, or GrammarFst. If you invoke it with
FST == StdFst and it notices that the actual FST type is
fst::VectorFst<fst::StdArc> or fst::ConstFst<fst::StdArc>, the decoder object
will internally cast itself to one that is templated on those more specific
types; this is an optimization for speed.
*/
template <typename FST, typename Token = decoder::StdToken>
class LatticeFasterDecoderTpl {
 public:
  // Type aliases derived from the FST type we decode with.
  using Arc = typename FST::Arc;
  using Label = typename Arc::Label;
  using StateId = typename Arc::StateId;
  using Weight = typename Arc::Weight;
  using ForwardLinkT = decoder::ForwardLink<Token>;

  // Instantiate this class once for each thing you have to decode.
  // This version of the constructor does not take ownership of
  // 'fst'.
  LatticeFasterDecoderTpl(
      const FST &fst, const LatticeFasterDecoderConfig &config,
      const std::shared_ptr<wenet::ContextGraph> &context_graph);

  // This version of the constructor takes ownership of the fst, and will delete
  // it when this object is destroyed.
  LatticeFasterDecoderTpl(const LatticeFasterDecoderConfig &config, FST *fst);

  // Replaces the decoding configuration.
  void SetOptions(const LatticeFasterDecoderConfig &config) { config_ = config; }

  // Returns the decoding configuration currently in use.
  const LatticeFasterDecoderConfig &GetOptions() const { return config_; }

  ~LatticeFasterDecoderTpl();

  /// Decodes until there are no more frames left in the "decodable" object..
  /// note, this may block waiting for input if the "decodable" object blocks.
  /// Returns true if any kind of traceback is available (not necessarily from a
  /// final state).
  bool Decode(DecodableInterface *decodable);

  /// says whether a final-state was active on the last frame. If it was not,
  /// the lattice (or traceback) will end with states that are not final-states.
  bool ReachedFinal() const {
    return FinalRelativeCost() != std::numeric_limits<BaseFloat>::infinity();
  }

  /// Outputs an FST corresponding to the single best path through the lattice.
  /// Returns true if result is nonempty (using the return status is deprecated,
  /// it will become void). If "use_final_probs" is true AND we reached the
  /// final-state of the graph then it will include those as final-probs, else
  /// it will treat all final-probs as one. Note: this just calls
  /// GetRawLattice() and figures out the shortest path.
  bool GetBestPath(Lattice *ofst, bool use_final_probs = true) const;

  /// Outputs an FST corresponding to the raw, state-level
  /// tracebacks. Returns true if result is nonempty.
  /// If "use_final_probs" is true AND we reached the final-state
  /// of the graph then it will include those as final-probs, else
  /// it will treat all final-probs as one.
  /// The raw lattice will be topologically sorted.
  ///
  /// See also GetRawLatticePruned in lattice-faster-online-decoder.h,
  /// which also supports a pruning beam, in case for some reason
  /// you want it pruned tighter than the regular lattice beam.
  /// We could put that here in future if needed.
  bool GetRawLattice(Lattice *ofst, bool use_final_probs = true) const;

  /// [Deprecated, users should now use GetRawLattice and determinize it
  /// themselves, e.g. using DeterminizeLatticePhonePrunedWrapper].
  /// Outputs an FST corresponding to the lattice-determinized
  /// lattice (one path per word sequence). Returns true if result is
  /// nonempty. If "use_final_probs" is true AND we reached the final-state of
  /// the graph then it will include those as final-probs, else it will treat
  /// all final-probs as one.
  bool GetLattice(CompactLattice *ofst, bool use_final_probs = true) const;

  /// InitDecoding initializes the decoding, and should only be used if you
  /// intend to call AdvanceDecoding(). If you call Decode(), you don't need to
  /// call this. You can also call InitDecoding if you have already decoded an
  /// utterance and want to start with a new utterance.
  void InitDecoding();

  /// This will decode until there are no more frames ready in the decodable
  /// object. You can keep calling it each time more frames become available.
  /// If max_num_frames is specified, it specifies the maximum number of frames
  /// the function will decode before returning.
  void AdvanceDecoding(DecodableInterface *decodable,
                       int32 max_num_frames = -1);

  /// This function may be optionally called after AdvanceDecoding(), when you
  /// do not plan to decode any further. It does an extra pruning step that
  /// will help to prune the lattices output by GetLattice and (particularly)
  /// GetRawLattice more completely, particularly toward the end of the
  /// utterance. If you call this, you cannot call AdvanceDecoding again (it
  /// will fail), and you cannot call GetLattice() and related functions with
  /// use_final_probs = false. Used to be called PruneActiveTokensFinal().
  void FinalizeDecoding();

  /// FinalRelativeCost() serves the same purpose as ReachedFinal(), but gives
  /// more information. It returns the difference between the best (final-cost
  /// plus cost) of any token on the final frame, and the best cost of any token
  /// on the final frame. If it is infinity it means no final-states were
  /// present on the final frame. It will usually be nonnegative. If it not
  /// too positive (e.g. < 5 is my first guess, but this is not tested) you can
  /// take it as a good indication that we reached the final-state with
  /// reasonable likelihood.
  BaseFloat FinalRelativeCost() const;

  // Returns the number of frames decoded so far. The value returned changes
  // whenever we call ProcessEmitting().
  inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }

 protected:
  // we make things protected instead of private, as code in
  // LatticeFasterOnlineDecoderTpl, which inherits from this, also uses the
  // internals.

  // Deletes the elements of the singly linked list tok->links.
  inline static void DeleteForwardLinks(Token *tok);

  // head of per-frame list of Tokens (list is in topological order),
  // and something saying whether we ever pruned it using PruneForwardLinks.
  struct TokenList {
    Token *toks;
    bool must_prune_forward_links;
    bool must_prune_tokens;
    TokenList()
        : toks(NULL), must_prune_forward_links(true), must_prune_tokens(true) {}
  };

  using Elem = typename HashList<StateId, Token *>::Elem;
  // Equivalent to:
  //   struct Elem {
  //     StateId key;
  //     Token *val;
  //     Elem *tail;
  //   };

  // Possibly resizes the hash table inside toks_ to suit 'num_toks' tokens.
  void PossiblyResizeHash(size_t num_toks);

  // FindOrAddToken either locates a token in hash of toks_, or if necessary
  // inserts a new, empty token (i.e. with no forward links) for the current
  // frame. [note: it's inserted if necessary into hash toks_ and also into the
  // singly linked list of tokens active on this frame (whose head is at
  // active_toks_[frame]). The frame_plus_one argument is the acoustic frame
  // index plus one, which is used to index into the active_toks_ array.
  // Returns the Token pointer. Sets "changed" (if non-NULL) to true if the
  // token was newly created or the cost changed.
  // If Token == StdToken, the 'backpointer' argument has no purpose (and will
  // hopefully be optimized out).
  inline Elem *FindOrAddToken(StateId state, int32 frame_plus_one,
                              BaseFloat tot_cost, Token *backpointer,
                              bool *changed);

  // prunes outgoing links for all tokens in active_toks_[frame]
  // it's called by PruneActiveTokens
  // all links, that have link_extra_cost > lattice_beam are pruned
  // delta is the amount by which the extra_costs must change
  // before we set *extra_costs_changed = true.
  // If delta is larger, we'll tend to go back less far
  // toward the beginning of the file.
  // extra_costs_changed is set to true if extra_cost was changed for any token
  // links_pruned is set to true if any link in any token was pruned
  void PruneForwardLinks(int32 frame_plus_one, bool *extra_costs_changed,
                         bool *links_pruned, BaseFloat delta);

  // This function computes the final-costs for tokens active on the final
  // frame. It outputs to final-costs, if non-NULL, a map from the Token*
  // pointer to the final-prob of the corresponding state, for all Tokens
  // that correspond to states that have final-probs. This map will be
  // empty if there were no final-probs. It outputs to
  // final_relative_cost, if non-NULL, the difference between the best
  // forward-cost including the final-prob cost, and the best forward-cost
  // without including the final-prob cost (this will usually be positive), or
  // infinity if there were no final-probs. [c.f. FinalRelativeCost(), which
  // outputs this quantity]. It outputs to final_best_cost, if
  // non-NULL, the lowest for any token t active on the final frame, of
  // forward-cost[t] + final-cost[t], where final-cost[t] is the final-cost in
  // the graph of the state corresponding to token t, or the best of
  // forward-cost[t] if there were no final-probs active on the final frame.
  // You cannot call this after FinalizeDecoding() has been called; in that
  // case you should get the answer from class-member variables.
  void ComputeFinalCosts(unordered_map<Token *, BaseFloat> *final_costs,
                         BaseFloat *final_relative_cost,
                         BaseFloat *final_best_cost) const;

  // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
  // on the final frame. If there are final tokens active, it uses
  // the final-probs for pruning, otherwise it treats all tokens as final.
  void PruneForwardLinksFinal();

  // Prune away any tokens on this frame that have no forward links.
  // [we don't do this in PruneForwardLinks because it would give us
  // a problem with dangling pointers].
  // It's called by PruneActiveTokens if any forward links have been pruned
  void PruneTokensForFrame(int32 frame_plus_one);

  // Go backwards through still-alive tokens, pruning them if the
  // forward+backward cost is more than lat_beam away from the best path. It's
  // possible to prove that this is "correct" in the sense that we won't lose
  // anything outside of lat_beam, regardless of what happens in the future.
  // delta controls when it considers a cost to have changed enough to continue
  // going backward and propagating the change. larger delta -> will recurse
  // less far.
  void PruneActiveTokens(BaseFloat delta);

  /// Gets the weight cutoff. Also counts the active tokens.
  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
                      BaseFloat *adaptive_beam, Elem **best_elem);

  /// Processes emitting arcs for one frame. Propagates from prev_toks_ to
  /// cur_toks_. Returns the cost cutoff for subsequent ProcessNonemitting() to
  /// use.
  BaseFloat ProcessEmitting(DecodableInterface *decodable);

  /// Processes nonemitting (epsilon) arcs for one frame. Called after
  /// ProcessEmitting() on each frame. The cost cutoff is computed by the
  /// preceding ProcessEmitting().
  void ProcessNonemitting(BaseFloat cost_cutoff);

  // HashList defined in ../util/hash-list.h. It actually allows us to maintain
  // more than one list (e.g. for current and previous frames), but only one of
  // them at a time can be indexed by StateId. It is indexed by frame-index
  // plus one, where the frame-index is zero-based, as used in decodable object.
  // That is, the emitting probs of frame t are accounted for in tokens at
  // toks_[t+1]. The zeroth frame is for nonemitting transition at the start of
  // the graph.
  HashList<StateId, Token *> toks_;

  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
  // frame (members of TokenList are toks, must_prune_forward_links,
  // must_prune_tokens).

  std::vector<const Elem *> queue_;  // temp variable used in ProcessNonemitting,
  std::vector<BaseFloat> tmp_array_;  // used in GetCutoff.

  // fst_ is a pointer to the FST we are decoding from.
  const FST *fst_;
  // delete_fst_ is true if the pointer fst_ needs to be deleted when this
  // object is destroyed.
  bool delete_fst_;

  std::vector<BaseFloat> cost_offsets_;  // This contains, for each
  // frame, an offset that was added to the acoustic log-likelihoods on that
  // frame in order to keep everything in a nice dynamic range i.e. close to
  // zero, to reduce roundoff errors.

  LatticeFasterDecoderConfig config_;
  int32 num_toks_;  // current total #toks allocated...
  bool warned_;

  /// decoding_finalized_ is true if someone called FinalizeDecoding(). [note,
  /// calling this is optional]. If true, it's forbidden to decode more. Also,
  /// if this is set, then the output of ComputeFinalCosts() is in the next
  /// three variables. The reason we need to do this is that after
  /// FinalizeDecoding() calls PruneTokensForFrame() for the final frame, some
  /// of the tokens on the last frame are freed, so we free the list from toks_
  /// to avoid having dangling pointers hanging around.
  bool decoding_finalized_;
  /// For the meaning of the next 3 variables, see the comment for
  /// decoding_finalized_ above., and ComputeFinalCosts().
  unordered_map<Token *, BaseFloat> final_costs_;
  BaseFloat final_relative_cost_;
  BaseFloat final_best_cost_;

  // Context-biasing graph passed in via the first constructor; defaults to
  // nullptr (no context biasing).
  std::shared_ptr<wenet::ContextGraph> context_graph_ = nullptr;

  // There are various cleanup tasks... the toks_ structure contains
  // singly linked lists of Token pointers, where Elem is the list type.
  // It also indexes them in a hash, indexed by state (this hash is only
  // maintained for the most recent frame). toks_.Clear()
  // deletes them from the hash and returns the list of Elems. The
  // function DeleteElems calls toks_.Delete(elem) for each elem in
  // the list, which returns ownership of the Elem to the toks_ structure
  // for reuse, but does not delete the Token pointer. The Token pointers
  // are reference-counted and are ultimately deleted in PruneTokensForFrame,
  // but are also linked together on each frame by their own linked-list,
  // using the "next" pointer. We delete them manually.
  void DeleteElems(Elem *list);

  // This function takes a singly linked list of tokens for a single frame, and
  // outputs a list of them in topological order (it will crash if no such order
  // can be found, which will typically be due to decoding graphs with epsilon
  // cycles, which are not allowed). Note: the output list may contain NULLs,
  // which the caller should pass over; it just happens to be more efficient for
  // the algorithm to output a list that contains NULLs.
  static void TopSortTokens(Token *tok_list,
                            std::vector<Token *> *topsorted_list);

  // Releases the per-frame token lists held in active_toks_ (implementation
  // in the .cc file).
  void ClearActiveTokens();

  KALDI_DISALLOW_COPY_AND_ASSIGN(LatticeFasterDecoderTpl);
};
// Convenience alias for the most common instantiation: decoding a generic
// fst::StdFst with the default (non-backpointer) StdToken type.
// (Modernized from 'typedef' to a 'using' alias, consistent with the alias
// declarations inside LatticeFasterDecoderTpl; callers are unaffected.)
using LatticeFasterDecoder =
    LatticeFasterDecoderTpl<fst::StdFst, decoder::StdToken>;
}
// end namespace kaldi.
#endif // KALDI_DECODER_LATTICE_FASTER_DECODER_H_
runtime/core/kaldi/decoder/lattice-faster-online-decoder.cc
0 → 100644
View file @
764b3a75
// decoder/lattice-faster-online-decoder.cc
// Copyright 2009-2012 Microsoft Corporation Mirko Hannemann
// 2013-2014 Johns Hopkins University (Author: Daniel Povey)
// 2014 Guoguo Chen
// 2014 IMSL, PKU-HKUST (author: Wei Shi)
// 2018 Zhehuai Chen
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
// see note at the top of lattice-faster-decoder.cc, about how to maintain this
// file in sync with lattice-faster-decoder.cc
#include <limits>
#include <queue>
#include <unordered_map>
#include <utility>
#include "decoder/lattice-faster-online-decoder.h"
namespace
kaldi
{
// Self-check: verifies that the decoder's traceback-based best path agrees
// (up to 'delta') with the best path obtained by running ShortestPath over
// the raw lattice. Returns true when the two paths are randomly equivalent;
// otherwise warns and returns false.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::TestGetBestPath(
    bool use_final_probs) const {
  // Reference answer: shortest path through the full raw lattice.
  Lattice shortest_path_lat;
  {
    Lattice raw_lat;
    this->GetRawLattice(&raw_lat, use_final_probs);
    ShortestPath(raw_lat, &shortest_path_lat);
  }

  // Answer under test: the incremental traceback implementation.
  Lattice traceback_lat;
  GetBestPath(&traceback_lat, use_final_probs);

  const BaseFloat delta = 0.1;   // weight tolerance for equivalence testing.
  const int32 num_paths = 1;     // number of random paths to compare.
  const bool equivalent = fst::RandEquivalent(
      shortest_path_lat, traceback_lat, num_paths, delta, rand());
  if (!equivalent) {
    KALDI_WARN << "Best-path test failed";
  }
  return equivalent;
}
// Outputs an FST corresponding to the single best path through the lattice.
template
<
typename
FST
>
bool
LatticeFasterOnlineDecoderTpl
<
FST
>::
GetBestPath
(
Lattice
*
olat
,
bool
use_final_probs
)
const
{
olat
->
DeleteStates
();
BaseFloat
final_graph_cost
;
BestPathIterator
iter
=
BestPathEnd
(
use_final_probs
,
&
final_graph_cost
);
if
(
iter
.
Done
())
return
false
;
// would have printed warning.
StateId
state
=
olat
->
AddState
();
olat
->
SetFinal
(
state
,
LatticeWeight
(
final_graph_cost
,
0.0
));
while
(
!
iter
.
Done
())
{
LatticeArc
arc
;
iter
=
TraceBackBestPath
(
iter
,
&
arc
);
arc
.
nextstate
=
state
;
StateId
new_state
=
olat
->
AddState
();
olat
->
AddArc
(
new_state
,
arc
);
state
=
new_state
;
}
olat
->
SetStart
(
state
);
return
true
;
}
// Returns an iterator positioned at the lowest-cost token on the last decoded
// frame, from which TraceBackBestPath() can walk the best path backwards.
// If use_final_probs is true and final-costs are available, a token's cost
// includes its final-cost (tokens with no final-cost become infinite).
// Outputs to *final_cost_out, if non-NULL, the final-cost component of the
// chosen token (0 if final-probs were not applied).
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::BestPathEnd(
    bool use_final_probs, BaseFloat *final_cost_out) const {
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "BestPathEnd() with use_final_probs == false";
  KALDI_ASSERT(this->NumFramesDecoded() > 0 &&
               "You cannot call BestPathEnd if no frames were decoded.");

  unordered_map<Token *, BaseFloat> final_costs_local;

  // After FinalizeDecoding() the final costs are cached in final_costs_;
  // otherwise they are computed on demand into the local map below.
  const unordered_map<Token *, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);

  // Singly linked list of tokens on last frame (access list through "next"
  // pointer).
  BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
  BaseFloat best_final_cost = 0;
  Token *best_tok = NULL;
  for (Token *tok = this->active_toks_.back().toks; tok != NULL;
       tok = tok->next) {
    BaseFloat cost = tok->tot_cost, final_cost = 0.0;
    if (use_final_probs && !final_costs.empty()) {
      // if we are instructed to use final-probs, and any final tokens were
      // active on final frame, include the final-prob in the cost of the token.
      typename unordered_map<Token *, BaseFloat>::const_iterator iter =
          final_costs.find(tok);
      if (iter != final_costs.end()) {
        final_cost = iter->second;
        cost += final_cost;
      } else {
        // Token has no final-prob: exclude it from the best-cost search.
        cost = std::numeric_limits<BaseFloat>::infinity();
      }
    }
    if (cost < best_cost) {
      best_cost = cost;
      best_tok = tok;
      best_final_cost = final_cost;
    }
  }
  if (best_tok == NULL) {
    // this should not happen, and is likely a code error or
    // caused by infinities in likelihoods, but I'm not making
    // it a fatal error for now.
    KALDI_WARN << "No final token found.";
  }
  if (final_cost_out) *final_cost_out = best_final_cost;
  // Note: best_tok may be NULL here; the returned iterator will be Done().
  return BestPathIterator(best_tok, this->NumFramesDecoded() - 1);
}
// Follows one step of the best path backwards from 'iter': fills in *oarc
// with the lowest-cost forward link from iter's token's backpointer into
// iter's token, and returns an iterator positioned at the backpointer.
// The frame index steps back by one only when the chosen link is emitting
// (ilabel != 0); for epsilon links the frame stays the same.
template <typename FST>
typename LatticeFasterOnlineDecoderTpl<FST>::BestPathIterator
LatticeFasterOnlineDecoderTpl<FST>::TraceBackBestPath(BestPathIterator iter,
                                                      LatticeArc *oarc) const {
  KALDI_ASSERT(!iter.Done() && oarc != NULL);
  Token *tok = static_cast<Token *>(iter.tok);
  int32 cur_t = iter.frame, step_t = 0;
  if (tok->backpointer != NULL) {
    // retrieve the correct forward link(with the best link cost)
    BaseFloat best_cost = std::numeric_limits<BaseFloat>::infinity();
    ForwardLinkT *link;
    for (link = tok->backpointer->links; link != NULL; link = link->next) {
      if (link->next_tok == tok) {  // this is a link to "tok"
        BaseFloat graph_cost = link->graph_cost,
                  acoustic_cost = link->acoustic_cost;
        BaseFloat cost = graph_cost + acoustic_cost;
        if (cost < best_cost) {
          oarc->ilabel = link->ilabel;
          oarc->olabel = link->olabel;
          if (link->ilabel != 0) {
            // Emitting link: undo the per-frame acoustic offset that was
            // added during decoding (see cost_offsets_), and step back a
            // frame.
            KALDI_ASSERT(static_cast<size_t>(cur_t) <
                         this->cost_offsets_.size());
            acoustic_cost -= this->cost_offsets_[cur_t];
            step_t = -1;
          } else {
            step_t = 0;
          }
          oarc->weight = LatticeWeight(graph_cost, acoustic_cost);
          best_cost = cost;
        }
      }
    }
    // After the loop 'link' is always NULL; the effective failure test is
    // best_cost still being infinity, i.e. no link into 'tok' was found.
    if (link == NULL &&
        best_cost == std::numeric_limits<BaseFloat>::infinity()) {
      // Did not find correct link.
      KALDI_ERR << "Error tracing best-path back (likely "
                << "bug in token-pruning algorithm)";
    }
  } else {
    // No backpointer: this is the start of the path; emit an epsilon arc
    // with unit weight.
    oarc->ilabel = 0;
    oarc->olabel = 0;
    oarc->weight = LatticeWeight::One();  // zero costs.
  }
  return BestPathIterator(tok->backpointer, cur_t + step_t);
}
// Outputs to *ofst the raw, state-level lattice, but only follows forward
// links whose destination token has extra_cost < beam — i.e. a pruning beam
// possibly tighter than the regular lattice beam. States are discovered via
// a breadth-first traversal (tok_queue) starting from the initial token on
// frame 0. Returns true if the resulting lattice is nonempty.
template <typename FST>
bool LatticeFasterOnlineDecoderTpl<FST>::GetRawLatticePruned(
    Lattice *ofst, bool use_final_probs, BaseFloat beam) const {
  typedef LatticeArc Arc;
  typedef Arc::StateId StateId;
  typedef Arc::Weight Weight;
  typedef Arc::Label Label;

  // Note: you can't use the old interface (Decode()) if you want to
  // get the lattice with use_final_probs = false. You'd have to do
  // InitDecoding() and then AdvanceDecoding().
  if (this->decoding_finalized_ && !use_final_probs)
    KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
              << "GetRawLattice() with use_final_probs == false";

  unordered_map<Token *, BaseFloat> final_costs_local;

  // After FinalizeDecoding() the final costs are cached; otherwise compute
  // them on demand into the local map.
  const unordered_map<Token *, BaseFloat> &final_costs =
      (this->decoding_finalized_ ? this->final_costs_ : final_costs_local);
  if (!this->decoding_finalized_ && use_final_probs)
    this->ComputeFinalCosts(&final_costs_local, NULL, NULL);

  ofst->DeleteStates();
  // num-frames plus one (since frames are one-based, and we have
  // an extra frame for the start-state).
  int32 num_frames = this->active_toks_.size() - 1;
  KALDI_ASSERT(num_frames > 0);
  // Sanity check: every frame must still have at least one active token.
  for (int32 f = 0; f <= num_frames; f++) {
    if (this->active_toks_[f].toks == NULL) {
      KALDI_WARN << "No tokens active on frame " << f
                 << ": not producing lattice.\n";
      return false;
    }
  }
  unordered_map<Token *, StateId> tok_map;
  std::queue<std::pair<Token *, int32> > tok_queue;
  // First initialize the queue and states. Put the initial state on the queue;
  // this is the last token in the list active_toks_[0].toks.
  for (Token *tok = this->active_toks_[0].toks; tok != NULL; tok = tok->next) {
    if (tok->next == NULL) {
      tok_map[tok] = ofst->AddState();
      ofst->SetStart(tok_map[tok]);
      std::pair<Token *, int32> tok_pair(tok, 0);  // #frame = 0
      tok_queue.push(tok_pair);
    }
  }

  // Next create states for "good" tokens
  while (!tok_queue.empty()) {
    std::pair<Token *, int32> cur_tok_pair = tok_queue.front();
    tok_queue.pop();
    Token *cur_tok = cur_tok_pair.first;
    int32 cur_frame = cur_tok_pair.second;
    // NOTE(review): signed/unsigned comparison (int32 vs size_t); also uses
    // <= rather than < — presumably intentional since cur_frame can equal
    // num_frames — confirm against cost_offsets_'s sizing.
    KALDI_ASSERT(cur_frame >= 0 && cur_frame <= this->cost_offsets_.size());

    typename unordered_map<Token *, StateId>::const_iterator iter =
        tok_map.find(cur_tok);
    KALDI_ASSERT(iter != tok_map.end());
    StateId cur_state = iter->second;

    for (ForwardLinkT *l = cur_tok->links; l != NULL; l = l->next) {
      Token *next_tok = l->next_tok;
      if (next_tok->extra_cost < beam) {
        // so both the current and the next token are good; create the arc
        // (epsilon links stay on the same frame; emitting links advance one).
        int32 next_frame = l->ilabel == 0 ? cur_frame : cur_frame + 1;
        StateId nextstate;
        if (tok_map.find(next_tok) == tok_map.end()) {
          // First time we see this token: create its state and enqueue it.
          nextstate = tok_map[next_tok] = ofst->AddState();
          tok_queue.push(std::pair<Token *, int32>(next_tok, next_frame));
        } else {
          nextstate = tok_map[next_tok];
        }
        // Undo the per-frame acoustic offset for emitting links (see
        // cost_offsets_ in the base class).
        BaseFloat cost_offset =
            (l->ilabel != 0 ? this->cost_offsets_[cur_frame] : 0);
        Arc arc(l->ilabel, l->olabel,
                Weight(l->graph_cost, l->acoustic_cost - cost_offset),
                nextstate);
        ofst->AddArc(cur_state, arc);
      }
    }
    if (cur_frame == num_frames) {
      // Tokens on the last frame become final states.
      if (use_final_probs && !final_costs.empty()) {
        typename unordered_map<Token *, BaseFloat>::const_iterator iter =
            final_costs.find(cur_tok);
        if (iter != final_costs.end())
          ofst->SetFinal(cur_state, LatticeWeight(iter->second, 0));
      } else {
        ofst->SetFinal(cur_state, LatticeWeight::One());
      }
    }
  }
  return (ofst->NumStates() != 0);
}
// Instantiate the template for the FST types that we'll need.
// (Explicit instantiation keeps the member definitions in this .cc file
// while letting other translation units link against them.)
template class LatticeFasterOnlineDecoderTpl<fst::Fst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::VectorFst<fst::StdArc> >;
template class LatticeFasterOnlineDecoderTpl<fst::ConstFst<fst::StdArc> >;
}
// end namespace kaldi.
Prev
1
…
19
20
21
22
23
24
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment